Compare commits
141 Commits
v0.4.3
...
v0.5.0.pos
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
50eed24d25 | ||
|
|
e38042d4af | ||
|
|
33e3b37242 | ||
|
|
1696efe6c9 | ||
|
|
6b0511a57b | ||
|
|
a8fda4f661 | ||
|
|
30299a41fa | ||
|
|
85657b5607 | ||
|
|
0ce7b952f8 | ||
|
|
39873476f8 | ||
|
|
03dccc886e | ||
|
|
a65634d3ae | ||
|
|
80aa7e91fc | ||
|
|
bd43973522 | ||
|
|
23ec72fa03 | ||
|
|
c2637a613b | ||
|
|
88407532e7 | ||
|
|
916d219d62 | ||
|
|
ea3890a5f0 | ||
|
|
2135cacb45 | ||
|
|
7d19de2e9c | ||
|
|
94a07bbdd8 | ||
|
|
b8d4dfff9c | ||
|
|
622d45128c | ||
|
|
51602eefd3 | ||
|
|
5cc50a531f | ||
|
|
5985e3427d | ||
|
|
8b82a89997 | ||
|
|
c3c2903e72 | ||
|
|
1a8bfd92d5 | ||
|
|
847cdcca1c | ||
|
|
e3c12bf6d2 | ||
|
|
3dd6853bc8 | ||
|
|
8f89d72090 | ||
|
|
99dac099ab | ||
|
|
c4bd03c7c5 | ||
|
|
dcbf4286af | ||
|
|
00e6a2dc53 | ||
|
|
2e02311a1b | ||
|
|
89ec06c33b | ||
|
|
9fde251bf0 | ||
|
|
4c2ffb28ff | ||
|
|
246598a6b1 | ||
|
|
8bab4959be | ||
|
|
3c4cebf751 | ||
|
|
d8f31f2f8b | ||
|
|
640052b069 | ||
|
|
351d5e7b82 | ||
|
|
a008629807 | ||
|
|
76477a93b7 | ||
|
|
77c87beb06 | ||
|
|
114332b88e | ||
|
|
cb77ad836f | ||
|
|
856c990041 | ||
|
|
c5602f0baa | ||
|
|
f7f9c5f97b | ||
|
|
2c0d933594 | ||
|
|
774d1035e4 | ||
|
|
6b29d6fe70 | ||
|
|
0bfa1c4f13 | ||
|
|
c81da5f56d | ||
|
|
68bc81703e | ||
|
|
5884c2b454 | ||
|
|
45f92c00cf | ||
|
|
5467ac3196 | ||
|
|
5d7e3d0176 | ||
|
|
0373e1837e | ||
|
|
c09dade2a2 | ||
|
|
8ea5e44a43 | ||
|
|
9fb900f90c | ||
|
|
c96fc06747 | ||
|
|
b3376e5c76 | ||
|
|
e69ded7d1c | ||
|
|
767c727a81 | ||
|
|
6840a71610 | ||
|
|
7a9cb294ae | ||
|
|
ca3ea51bde | ||
|
|
dc49fb892c | ||
|
|
18a277b52d | ||
|
|
8d75fe48ca | ||
|
|
388596c914 | ||
|
|
baa15a9ec3 | ||
|
|
15063741e3 | ||
|
|
ccdc490dda | ||
|
|
a31cab7556 | ||
|
|
828da0d44e | ||
|
|
abe855d637 | ||
|
|
4efff036f0 | ||
|
|
89c920785f | ||
|
|
7b0a0dfb22 | ||
|
|
3a6ae1d33c | ||
|
|
8f1729b829 | ||
|
|
6a7c7711a2 | ||
|
|
0f83ddd4d7 | ||
|
|
065aff6c16 | ||
|
|
3d33e372a1 | ||
|
|
faf71bcd4b | ||
|
|
f270a39537 | ||
|
|
51a08e7d8f | ||
|
|
eb8fcd2666 | ||
|
|
5563a4dea8 | ||
|
|
ccd4f129e8 | ||
|
|
02cc3b51a7 | ||
|
|
d5b1eb081e | ||
|
|
f0a500545f | ||
|
|
c65146e75e | ||
|
|
41ca62cf03 | ||
|
|
974fc9b845 | ||
|
|
fee4dcc33a | ||
|
|
650a4cc55e | ||
|
|
9ca62d8668 | ||
|
|
45c35f0d58 | ||
|
|
9ba093b4f4 | ||
|
|
27208be66e | ||
|
|
87d5abef75 | ||
|
|
ec784b2526 | ||
|
|
a58f24e590 | ||
|
|
f42a006b15 | ||
|
|
3a434b07ed | ||
|
|
bd0e7802e0 | ||
|
|
06b2550cbb | ||
|
|
f775a07e30 | ||
|
|
4f0d17c05c | ||
|
|
10c38e3e46 | ||
|
|
cafb8e06c5 | ||
|
|
cbb2f59cc8 | ||
|
|
0ab278ca31 | ||
|
|
7a64d24aad | ||
|
|
dfbe60dc62 | ||
|
|
a66cf40b20 | ||
|
|
f790ad3c50 | ||
|
|
ed59a7ed23 | ||
|
|
044793d8df | ||
|
|
c2d6d2f960 | ||
|
|
8279078e21 | ||
|
|
b9c0605a8e | ||
|
|
37464a0f74 | ||
|
|
c354072828 | ||
|
|
f081c3ce4b | ||
|
|
260d119e86 | ||
|
|
a360ff80bb |
26
.buildkite/nightly-benchmarks/kickoff-pipeline.sh
Executable file
26
.buildkite/nightly-benchmarks/kickoff-pipeline.sh
Executable file
@@ -0,0 +1,26 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# Install system packages
|
||||||
|
apt update
|
||||||
|
apt install -y curl jq
|
||||||
|
|
||||||
|
# Install minijinja for templating
|
||||||
|
curl -sSfL https://github.com/mitsuhiko/minijinja/releases/latest/download/minijinja-cli-installer.sh | sh
|
||||||
|
source $HOME/.cargo/env
|
||||||
|
|
||||||
|
# If BUILDKITE_PULL_REQUEST != "false", then we check the PR labels using curl and jq
|
||||||
|
if [ "$BUILDKITE_PULL_REQUEST" != "false" ]; then
|
||||||
|
PR_LABELS=$(curl -s "https://api.github.com/repos/vllm-project/vllm/pulls/$BUILDKITE_PULL_REQUEST" | jq -r '.labels[].name')
|
||||||
|
|
||||||
|
if [[ $PR_LABELS == *"perf-benchmarks"* ]]; then
|
||||||
|
echo "This PR has the 'perf-benchmarks' label. Proceeding with the nightly benchmarks."
|
||||||
|
else
|
||||||
|
echo "This PR does not have the 'perf-benchmarks' label. Skipping the nightly benchmarks."
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Upload sample.yaml
|
||||||
|
buildkite-agent pipeline upload .buildkite/nightly-benchmarks/sample.yaml
|
||||||
39
.buildkite/nightly-benchmarks/sample.yaml
Normal file
39
.buildkite/nightly-benchmarks/sample.yaml
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
steps:
|
||||||
|
# NOTE(simon): You can create separate blocks for different jobs
|
||||||
|
- label: "A100: NVIDIA SMI"
|
||||||
|
agents:
|
||||||
|
queue: A100
|
||||||
|
plugins:
|
||||||
|
- kubernetes:
|
||||||
|
podSpec:
|
||||||
|
containers:
|
||||||
|
# - image: us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT
|
||||||
|
# TODO(simon): check latest main branch or use the PR image.
|
||||||
|
- image: us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:45c35f0d58f4508bf43bd6af1d3d0d0ec0c915e6
|
||||||
|
command:
|
||||||
|
- bash -c 'nvidia-smi && nvidia-smi topo -m && pwd && ls'
|
||||||
|
resources:
|
||||||
|
limits:
|
||||||
|
nvidia.com/gpu: 8
|
||||||
|
volumeMounts:
|
||||||
|
- name: devshm
|
||||||
|
mountPath: /dev/shm
|
||||||
|
nodeSelector:
|
||||||
|
nvidia.com/gpu.product: NVIDIA-A100-SXM4-80GB
|
||||||
|
volumes:
|
||||||
|
- name: devshm
|
||||||
|
emptyDir:
|
||||||
|
medium: Memory
|
||||||
|
# TODO(simon): bring H100 online
|
||||||
|
# - label: "H100: NVIDIA SMI"
|
||||||
|
# agents:
|
||||||
|
# queue: H100
|
||||||
|
# plugins:
|
||||||
|
# - docker#v5.11.0:
|
||||||
|
# image: us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:45c35f0d58f4508bf43bd6af1d3d0d0ec0c915e6
|
||||||
|
# command:
|
||||||
|
# - bash -c 'nvidia-smi && nvidia-smi topo -m'
|
||||||
|
# propagate-environment: true
|
||||||
|
# ipc: host
|
||||||
|
# gpus: all
|
||||||
|
|
||||||
@@ -50,16 +50,16 @@ echo "### Serving Benchmarks" >> benchmark_results.md
|
|||||||
sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line
|
sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line
|
||||||
echo "" >> benchmark_results.md
|
echo "" >> benchmark_results.md
|
||||||
echo '```' >> benchmark_results.md
|
echo '```' >> benchmark_results.md
|
||||||
tail -n 20 benchmark_serving.txt >> benchmark_results.md # last 20 lines
|
tail -n 24 benchmark_serving.txt >> benchmark_results.md # last 24 lines
|
||||||
echo '```' >> benchmark_results.md
|
echo '```' >> benchmark_results.md
|
||||||
|
|
||||||
# if the agent binary is not found, skip uploading the results, exit 0
|
# if the agent binary is not found, skip uploading the results, exit 0
|
||||||
if [ ! -f /workspace/buildkite-agent ]; then
|
if [ ! -f /usr/bin/buildkite-agent ]; then
|
||||||
exit 0
|
exit 0
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# upload the results to buildkite
|
# upload the results to buildkite
|
||||||
/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md
|
buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md
|
||||||
|
|
||||||
# exit with the exit code of the benchmarks
|
# exit with the exit code of the benchmarks
|
||||||
if [ $bench_latency_exit_code -ne 0 ]; then
|
if [ $bench_latency_exit_code -ne 0 ]; then
|
||||||
@@ -75,4 +75,4 @@ if [ $bench_serving_exit_code -ne 0 ]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
rm ShareGPT_V3_unfiltered_cleaned_split.json
|
rm ShareGPT_V3_unfiltered_cleaned_split.json
|
||||||
/workspace/buildkite-agent artifact upload "*.json"
|
buildkite-agent artifact upload "*.json"
|
||||||
|
|||||||
@@ -10,5 +10,15 @@ remove_docker_container() { docker rm -f cpu-test || true; }
|
|||||||
trap remove_docker_container EXIT
|
trap remove_docker_container EXIT
|
||||||
remove_docker_container
|
remove_docker_container
|
||||||
|
|
||||||
# Run the image and launch offline inference
|
# Run the image
|
||||||
docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 vllm/examples/offline_inference.py
|
docker run -itd -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 --cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test
|
||||||
|
|
||||||
|
# offline inference
|
||||||
|
docker exec cpu-test bash -c "python3 examples/offline_inference.py"
|
||||||
|
|
||||||
|
# Run basic model test
|
||||||
|
docker exec cpu-test bash -c "cd tests;
|
||||||
|
pip install pytest Pillow protobuf
|
||||||
|
bash ../.buildkite/download-images.sh
|
||||||
|
cd ../
|
||||||
|
pytest -v -s tests/models --ignore=tests/models/test_llava.py --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py"
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ steps:
|
|||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
commands:
|
commands:
|
||||||
|
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
|
||||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
|
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||||
@@ -45,7 +46,9 @@ steps:
|
|||||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
|
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
|
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
|
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||||
- pytest -v -s spec_decode/e2e/test_integration_dist.py
|
- pytest -v -s spec_decode/e2e/test_integration_dist.py
|
||||||
|
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||||
|
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
|
||||||
|
|
||||||
- label: Distributed Tests (Multiple Groups)
|
- label: Distributed Tests (Multiple Groups)
|
||||||
#mirror_hardwares: [amd]
|
#mirror_hardwares: [amd]
|
||||||
@@ -62,7 +65,6 @@ steps:
|
|||||||
mirror_hardwares: [amd]
|
mirror_hardwares: [amd]
|
||||||
|
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s test_inputs.py
|
|
||||||
- pytest -v -s entrypoints -m llm
|
- pytest -v -s entrypoints -m llm
|
||||||
- pytest -v -s entrypoints -m openai
|
- pytest -v -s entrypoints -m openai
|
||||||
|
|
||||||
@@ -79,6 +81,13 @@ steps:
|
|||||||
- python3 llava_example.py
|
- python3 llava_example.py
|
||||||
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
||||||
|
|
||||||
|
- label: Inputs Test
|
||||||
|
#mirror_hardwares: [amd]
|
||||||
|
commands:
|
||||||
|
- bash ../.buildkite/download-images.sh
|
||||||
|
- pytest -v -s test_inputs.py
|
||||||
|
- pytest -v -s multimodal
|
||||||
|
|
||||||
- label: Kernels Test %N
|
- label: Kernels Test %N
|
||||||
#mirror_hardwares: [amd]
|
#mirror_hardwares: [amd]
|
||||||
command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||||
@@ -87,14 +96,13 @@ steps:
|
|||||||
- label: Models Test
|
- label: Models Test
|
||||||
#mirror_hardwares: [amd]
|
#mirror_hardwares: [amd]
|
||||||
commands:
|
commands:
|
||||||
- bash ../.buildkite/download-images.sh
|
- pytest -v -s models -m \"not llava\"
|
||||||
- pytest -v -s models --ignore=models/test_llava.py
|
|
||||||
|
|
||||||
- label: Llava Test
|
- label: Llava Test
|
||||||
mirror_hardwares: [amd]
|
mirror_hardwares: [amd]
|
||||||
commands:
|
commands:
|
||||||
- bash ../.buildkite/download-images.sh
|
- bash ../.buildkite/download-images.sh
|
||||||
- pytest -v -s models/test_llava.py
|
- pytest -v -s models -m llava
|
||||||
|
|
||||||
- label: Prefix Caching Test
|
- label: Prefix Caching Test
|
||||||
mirror_hardwares: [amd]
|
mirror_hardwares: [amd]
|
||||||
@@ -118,7 +126,10 @@ steps:
|
|||||||
|
|
||||||
- label: Speculative decoding tests
|
- label: Speculative decoding tests
|
||||||
#mirror_hardwares: [amd]
|
#mirror_hardwares: [amd]
|
||||||
command: pytest -v -s spec_decode
|
commands:
|
||||||
|
# See https://github.com/vllm-project/vllm/issues/5152
|
||||||
|
- export VLLM_ATTENTION_BACKEND=XFORMERS
|
||||||
|
- pytest -v -s spec_decode
|
||||||
|
|
||||||
- label: LoRA Test %N
|
- label: LoRA Test %N
|
||||||
#mirror_hardwares: [amd]
|
#mirror_hardwares: [amd]
|
||||||
@@ -130,14 +141,7 @@ steps:
|
|||||||
num_gpus: 4
|
num_gpus: 4
|
||||||
# This test runs llama 13B, so it is required to run on 4 GPUs.
|
# This test runs llama 13B, so it is required to run on 4 GPUs.
|
||||||
commands:
|
commands:
|
||||||
# Temporarily run this way because we cannot clean up GPU mem usage
|
- pytest -v -s -x lora/test_long_context.py
|
||||||
# for multi GPU tests.
|
|
||||||
# TODO(sang): Fix it.
|
|
||||||
- pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced
|
|
||||||
- pytest -v -s lora/test_long_context.py::test_batched_rope_kernel
|
|
||||||
- pytest -v -s lora/test_long_context.py::test_self_consistency
|
|
||||||
- pytest -v -s lora/test_long_context.py::test_quality
|
|
||||||
- pytest -v -s lora/test_long_context.py::test_max_len
|
|
||||||
|
|
||||||
- label: Tensorizer Test
|
- label: Tensorizer Test
|
||||||
#mirror_hardwares: [amd]
|
#mirror_hardwares: [amd]
|
||||||
|
|||||||
92
.buildkite/test-template-aws.j2
Normal file
92
.buildkite/test-template-aws.j2
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
{% set docker_image = "public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT" %}
|
||||||
|
{% set default_working_dir = "/vllm-workspace/tests" %}
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- label: ":docker: build image"
|
||||||
|
agents:
|
||||||
|
queue: cpu_queue
|
||||||
|
commands:
|
||||||
|
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||||
|
- "docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --tag {{ docker_image }} --target test --progress plain ."
|
||||||
|
- "docker push {{ docker_image }}"
|
||||||
|
env:
|
||||||
|
DOCKER_BUILDKIT: "1"
|
||||||
|
retry:
|
||||||
|
automatic:
|
||||||
|
- exit_status: -1 # Agent was lost
|
||||||
|
limit: 5
|
||||||
|
- exit_status: -10 # Agent was lost
|
||||||
|
limit: 5
|
||||||
|
- wait
|
||||||
|
|
||||||
|
- group: "AMD Tests"
|
||||||
|
depends_on: ~
|
||||||
|
steps:
|
||||||
|
{% for step in steps %}
|
||||||
|
{% if step.mirror_hardwares and "amd" in step.mirror_hardwares %}
|
||||||
|
- label: "AMD: {{ step.label }}"
|
||||||
|
agents:
|
||||||
|
queue: amd
|
||||||
|
command: bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" ; ")) | safe }}"
|
||||||
|
env:
|
||||||
|
DOCKER_BUILDKIT: "1"
|
||||||
|
soft_fail: true
|
||||||
|
{% endif %}
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
|
- label: "Neuron Test"
|
||||||
|
depends_on: ~
|
||||||
|
agents:
|
||||||
|
queue: neuron
|
||||||
|
command: bash .buildkite/run-neuron-test.sh
|
||||||
|
soft_fail: false
|
||||||
|
|
||||||
|
- label: "Intel Test"
|
||||||
|
depends_on: ~
|
||||||
|
agents:
|
||||||
|
queue: intel
|
||||||
|
command: bash .buildkite/run-cpu-test.sh
|
||||||
|
|
||||||
|
{% for step in steps %}
|
||||||
|
- label: "{{ step.label }}"
|
||||||
|
agents:
|
||||||
|
{% if step.label == "Documentation Build" %}
|
||||||
|
queue: small_cpu_queue
|
||||||
|
{% elif step.no_gpu %}
|
||||||
|
queue: cpu_queue
|
||||||
|
{% elif step.num_gpus == 2 or step.num_gpus == 4 %}
|
||||||
|
queue: gpu_4_queue
|
||||||
|
{% else %}
|
||||||
|
queue: gpu_1_queue
|
||||||
|
{% endif %}
|
||||||
|
soft_fail: {{ step.soft_fail or false }}
|
||||||
|
{% if step.parallelism %}
|
||||||
|
parallelism: {{ step.parallelism }}
|
||||||
|
{% endif %}
|
||||||
|
retry:
|
||||||
|
automatic:
|
||||||
|
- exit_status: -1 # Agent was lost
|
||||||
|
limit: 5
|
||||||
|
- exit_status: -10 # Agent was lost
|
||||||
|
limit: 5
|
||||||
|
plugins:
|
||||||
|
- docker#v5.2.0:
|
||||||
|
image: {{ docker_image }}
|
||||||
|
always-pull: true
|
||||||
|
propagate-environment: true
|
||||||
|
{% if not step.no_gpu %}
|
||||||
|
gpus: all
|
||||||
|
{% endif %}
|
||||||
|
{% if step.label == "Benchmarks" %}
|
||||||
|
mount-buildkite-agent: true
|
||||||
|
{% endif %}
|
||||||
|
command: ["bash", "-c", "cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}"]
|
||||||
|
environment:
|
||||||
|
- VLLM_USAGE_SOURCE=ci-test
|
||||||
|
- HF_TOKEN
|
||||||
|
{% if step.label == "Speculative decoding tests" %}
|
||||||
|
- VLLM_ATTENTION_BACKEND=XFORMERS
|
||||||
|
{% endif %}
|
||||||
|
volumes:
|
||||||
|
- /dev/shm:/dev/shm
|
||||||
|
{% endfor %}
|
||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- label: ":docker: build image"
|
- label: ":docker: build image"
|
||||||
commands:
|
commands:
|
||||||
- "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
|
- "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
|
||||||
- "docker push {{ docker_image }}"
|
- "docker push {{ docker_image }}"
|
||||||
env:
|
env:
|
||||||
@@ -28,6 +28,7 @@ steps:
|
|||||||
command: bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" ; ")) | safe }}"
|
command: bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" ; ")) | safe }}"
|
||||||
env:
|
env:
|
||||||
DOCKER_BUILDKIT: "1"
|
DOCKER_BUILDKIT: "1"
|
||||||
|
soft_fail: true
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
|
||||||
@@ -36,10 +37,12 @@ steps:
|
|||||||
agents:
|
agents:
|
||||||
queue: neuron
|
queue: neuron
|
||||||
command: bash .buildkite/run-neuron-test.sh
|
command: bash .buildkite/run-neuron-test.sh
|
||||||
soft_fail: true
|
soft_fail: false
|
||||||
|
|
||||||
- label: "Intel Test"
|
- label: "Intel Test"
|
||||||
depends_on: ~
|
depends_on: ~
|
||||||
|
agents:
|
||||||
|
queue: intel
|
||||||
command: bash .buildkite/run-cpu-test.sh
|
command: bash .buildkite/run-cpu-test.sh
|
||||||
|
|
||||||
{% for step in steps %}
|
{% for step in steps %}
|
||||||
|
|||||||
1
.github/workflows/mypy.yaml
vendored
1
.github/workflows/mypy.yaml
vendored
@@ -37,6 +37,7 @@ jobs:
|
|||||||
mypy vllm/distributed --config-file pyproject.toml
|
mypy vllm/distributed --config-file pyproject.toml
|
||||||
mypy vllm/entrypoints --config-file pyproject.toml
|
mypy vllm/entrypoints --config-file pyproject.toml
|
||||||
mypy vllm/executor --config-file pyproject.toml
|
mypy vllm/executor --config-file pyproject.toml
|
||||||
|
mypy vllm/multimodal --config-file pyproject.toml
|
||||||
mypy vllm/usage --config-file pyproject.toml
|
mypy vllm/usage --config-file pyproject.toml
|
||||||
mypy vllm/*.py --config-file pyproject.toml
|
mypy vllm/*.py --config-file pyproject.toml
|
||||||
mypy vllm/transformers_utils --config-file pyproject.toml
|
mypy vllm/transformers_utils --config-file pyproject.toml
|
||||||
|
|||||||
2
.github/workflows/ruff.yml
vendored
2
.github/workflows/ruff.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install ruff==0.1.5 codespell==2.2.6 tomli==2.0.1 isort==5.13.2
|
pip install ruff==0.1.5 codespell==2.3.0 tomli==2.0.1 isort==5.13.2
|
||||||
- name: Analysing the code with ruff
|
- name: Analysing the code with ruff
|
||||||
run: |
|
run: |
|
||||||
ruff .
|
ruff .
|
||||||
|
|||||||
@@ -66,19 +66,6 @@ endif()
|
|||||||
#
|
#
|
||||||
find_package(Torch REQUIRED)
|
find_package(Torch REQUIRED)
|
||||||
|
|
||||||
#
|
|
||||||
# Normally `torch.utils.cpp_extension.CUDAExtension` would add
|
|
||||||
# `libtorch_python.so` for linking against an extension. Torch's cmake
|
|
||||||
# configuration does not include this library (presumably since the cmake
|
|
||||||
# config is used for standalone C++ binaries that link against torch).
|
|
||||||
# The `libtorch_python.so` library defines some of the glue code between
|
|
||||||
# torch/python via pybind and is required by VLLM extensions for this
|
|
||||||
# reason. So, add it by manually with `find_library` using torch's
|
|
||||||
# installed library path.
|
|
||||||
#
|
|
||||||
find_library(torch_python_LIBRARY torch_python PATHS
|
|
||||||
"${TORCH_INSTALL_PREFIX}/lib")
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Forward the non-CUDA device extensions to external CMake scripts.
|
# Forward the non-CUDA device extensions to external CMake scripts.
|
||||||
#
|
#
|
||||||
@@ -171,7 +158,7 @@ set(VLLM_EXT_SRC
|
|||||||
"csrc/quantization/fp8/common.cu"
|
"csrc/quantization/fp8/common.cu"
|
||||||
"csrc/cuda_utils_kernels.cu"
|
"csrc/cuda_utils_kernels.cu"
|
||||||
"csrc/moe_align_block_size_kernels.cu"
|
"csrc/moe_align_block_size_kernels.cu"
|
||||||
"csrc/pybind.cpp")
|
"csrc/torch_bindings.cpp")
|
||||||
|
|
||||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
@@ -192,9 +179,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
||||||
"csrc/custom_all_reduce.cu"
|
"csrc/custom_all_reduce.cu"
|
||||||
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu"
|
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
|
||||||
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu"
|
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
|
||||||
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu")
|
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
|
||||||
|
|
||||||
#
|
#
|
||||||
# The CUTLASS kernels for Hopper require sm90a to be enabled.
|
# The CUTLASS kernels for Hopper require sm90a to be enabled.
|
||||||
@@ -202,7 +189,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
# That adds an extra 17MB to compiled binary, so instead we selectively enable it.
|
# That adds an extra 17MB to compiled binary, so instead we selectively enable it.
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
|
||||||
set_source_files_properties(
|
set_source_files_properties(
|
||||||
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
|
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
|
||||||
PROPERTIES
|
PROPERTIES
|
||||||
COMPILE_FLAGS
|
COMPILE_FLAGS
|
||||||
"-gencode arch=compute_90a,code=sm_90a")
|
"-gencode arch=compute_90a,code=sm_90a")
|
||||||
@@ -218,6 +205,7 @@ define_gpu_extension_target(
|
|||||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||||
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
|
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
|
||||||
|
USE_SABI 3
|
||||||
WITH_SOABI)
|
WITH_SOABI)
|
||||||
|
|
||||||
#
|
#
|
||||||
@@ -225,7 +213,7 @@ define_gpu_extension_target(
|
|||||||
#
|
#
|
||||||
|
|
||||||
set(VLLM_MOE_EXT_SRC
|
set(VLLM_MOE_EXT_SRC
|
||||||
"csrc/moe/moe_ops.cpp"
|
"csrc/moe/torch_bindings.cpp"
|
||||||
"csrc/moe/topk_softmax_kernels.cu")
|
"csrc/moe/topk_softmax_kernels.cu")
|
||||||
|
|
||||||
define_gpu_extension_target(
|
define_gpu_extension_target(
|
||||||
@@ -235,6 +223,7 @@ define_gpu_extension_target(
|
|||||||
SOURCES ${VLLM_MOE_EXT_SRC}
|
SOURCES ${VLLM_MOE_EXT_SRC}
|
||||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||||
|
USE_SABI 3
|
||||||
WITH_SOABI)
|
WITH_SOABI)
|
||||||
|
|
||||||
#
|
#
|
||||||
@@ -249,7 +238,7 @@ set(VLLM_PUNICA_EXT_SRC
|
|||||||
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
|
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
|
||||||
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
|
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
|
||||||
"csrc/punica/punica_ops.cu"
|
"csrc/punica/punica_ops.cu"
|
||||||
"csrc/punica/punica_pybind.cpp")
|
"csrc/punica/torch_bindings.cpp")
|
||||||
|
|
||||||
#
|
#
|
||||||
# Copy GPU compilation flags+update for punica
|
# Copy GPU compilation flags+update for punica
|
||||||
@@ -286,6 +275,7 @@ if (VLLM_PUNICA_GPU_ARCHES)
|
|||||||
SOURCES ${VLLM_PUNICA_EXT_SRC}
|
SOURCES ${VLLM_PUNICA_EXT_SRC}
|
||||||
COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
|
COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
|
||||||
ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
|
ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
|
||||||
|
USE_SABI 3
|
||||||
WITH_SOABI)
|
WITH_SOABI)
|
||||||
else()
|
else()
|
||||||
message(WARNING "Unable to create _punica_C target because none of the "
|
message(WARNING "Unable to create _punica_C target because none of the "
|
||||||
@@ -311,6 +301,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
|||||||
message(STATUS "Enabling C extension.")
|
message(STATUS "Enabling C extension.")
|
||||||
add_dependencies(default _C)
|
add_dependencies(default _C)
|
||||||
|
|
||||||
|
message(STATUS "Enabling moe extension.")
|
||||||
|
add_dependencies(default _moe_C)
|
||||||
|
|
||||||
# Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or
|
# Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or
|
||||||
# VLLM_INSTALL_PUNICA_KERNELS is set in the environment and
|
# VLLM_INSTALL_PUNICA_KERNELS is set in the environment and
|
||||||
# there are supported target arches.
|
# there are supported target arches.
|
||||||
@@ -320,8 +313,3 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
|||||||
add_dependencies(default _punica_C)
|
add_dependencies(default _punica_C)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|
||||||
message(STATUS "Enabling moe extension.")
|
|
||||||
add_dependencies(default _moe_C)
|
|
||||||
endif()
|
|
||||||
|
|||||||
24
Dockerfile
24
Dockerfile
@@ -10,7 +10,7 @@
|
|||||||
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS dev
|
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS dev
|
||||||
|
|
||||||
RUN apt-get update -y \
|
RUN apt-get update -y \
|
||||||
&& apt-get install -y python3-pip git
|
&& apt-get install -y python3-pip git curl sudo
|
||||||
|
|
||||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||||
@@ -27,6 +27,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
|||||||
pip install -r requirements-cuda.txt
|
pip install -r requirements-cuda.txt
|
||||||
|
|
||||||
# install development dependencies
|
# install development dependencies
|
||||||
|
COPY requirements-lint.txt requirements-lint.txt
|
||||||
|
COPY requirements-test.txt requirements-test.txt
|
||||||
COPY requirements-dev.txt requirements-dev.txt
|
COPY requirements-dev.txt requirements-dev.txt
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
pip install -r requirements-dev.txt
|
pip install -r requirements-dev.txt
|
||||||
@@ -70,10 +72,28 @@ ENV NVCC_THREADS=$nvcc_threads
|
|||||||
# make sure punica kernels are built (for LoRA)
|
# make sure punica kernels are built (for LoRA)
|
||||||
ENV VLLM_INSTALL_PUNICA_KERNELS=1
|
ENV VLLM_INSTALL_PUNICA_KERNELS=1
|
||||||
|
|
||||||
|
ARG USE_SCCACHE
|
||||||
|
# if USE_SCCACHE is set, use sccache to speed up compilation
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
|
if [ "$USE_SCCACHE" = "1" ]; then \
|
||||||
|
echo "Installing sccache..." \
|
||||||
|
&& curl -L -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz \
|
||||||
|
&& tar -xzf sccache.tar.gz \
|
||||||
|
&& sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
|
||||||
|
&& rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
|
||||||
|
&& export SCCACHE_BUCKET=vllm-build-sccache \
|
||||||
|
&& export SCCACHE_REGION=us-west-2 \
|
||||||
|
&& sccache --show-stats \
|
||||||
|
&& python3 setup.py bdist_wheel --dist-dir=dist \
|
||||||
|
&& sccache --show-stats; \
|
||||||
|
fi
|
||||||
|
|
||||||
ENV CCACHE_DIR=/root/.cache/ccache
|
ENV CCACHE_DIR=/root/.cache/ccache
|
||||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||||
--mount=type=cache,target=/root/.cache/pip \
|
--mount=type=cache,target=/root/.cache/pip \
|
||||||
python3 setup.py bdist_wheel --dist-dir=dist
|
if [ "$USE_SCCACHE" != "1" ]; then \
|
||||||
|
python3 setup.py bdist_wheel --dist-dir=dist; \
|
||||||
|
fi
|
||||||
|
|
||||||
# check the size of the wheel, we cannot upload wheels larger than 100MB
|
# check the size of the wheel, we cannot upload wheels larger than 100MB
|
||||||
COPY .buildkite/check-wheel-size.py check-wheel-size.py
|
COPY .buildkite/check-wheel-size.py check-wheel-size.py
|
||||||
|
|||||||
@@ -1,13 +1,19 @@
|
|||||||
# This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform.
|
# This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform.
|
||||||
|
|
||||||
FROM ubuntu:22.04
|
FROM ubuntu:22.04 AS cpu-test-1
|
||||||
|
|
||||||
RUN apt-get update -y \
|
RUN apt-get update -y \
|
||||||
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
|
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 \
|
||||||
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
|
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
|
||||||
|
|
||||||
|
RUN echo 'export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD' >> ~/.bashrc
|
||||||
|
|
||||||
|
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl
|
||||||
|
|
||||||
RUN pip install --upgrade pip \
|
RUN pip install --upgrade pip \
|
||||||
&& pip install wheel packaging ninja setuptools>=49.4.0 numpy
|
&& pip install wheel packaging ninja "setuptools>=49.4.0" numpy
|
||||||
|
|
||||||
|
FROM cpu-test-1 AS build
|
||||||
|
|
||||||
COPY ./ /workspace/vllm
|
COPY ./ /workspace/vllm
|
||||||
|
|
||||||
@@ -19,4 +25,6 @@ RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install
|
|||||||
|
|
||||||
WORKDIR /workspace/
|
WORKDIR /workspace/
|
||||||
|
|
||||||
|
RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
|
||||||
|
|
||||||
CMD ["/bin/bash"]
|
CMD ["/bin/bash"]
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt
|
|||||||
RUN cd /app/vllm \
|
RUN cd /app/vllm \
|
||||||
&& python3 -m pip install -U -r requirements-neuron.txt
|
&& python3 -m pip install -U -r requirements-neuron.txt
|
||||||
|
|
||||||
ENV VLLM_BUILD_WITH_NEURON 1
|
ENV VLLM_TARGET_DEVICE neuron
|
||||||
RUN cd /app/vllm \
|
RUN cd /app/vllm \
|
||||||
&& pip install -e . \
|
&& pip install -e . \
|
||||||
&& cd ..
|
&& cd ..
|
||||||
|
|||||||
@@ -106,8 +106,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
|||||||
pip install -U -r requirements-rocm.txt \
|
pip install -U -r requirements-rocm.txt \
|
||||||
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
|
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
|
||||||
&& python3 setup.py install \
|
&& python3 setup.py install \
|
||||||
&& cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \
|
&& cp build/lib.linux-x86_64-cpython-39/vllm/_C.abi3.so vllm/ \
|
||||||
&& cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \
|
&& cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.abi3.so vllm/ \
|
||||||
|
&& cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.abi3.so vllm/ \
|
||||||
&& cd ..
|
&& cd ..
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
19
Dockerfile.tpu
Normal file
19
Dockerfile.tpu
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
ARG NIGHTLY_DATE="20240601"
|
||||||
|
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
|
||||||
|
|
||||||
|
FROM $BASE_IMAGE
|
||||||
|
|
||||||
|
WORKDIR /workspace
|
||||||
|
COPY . /workspace/vllm
|
||||||
|
|
||||||
|
ENV VLLM_TARGET_DEVICE="tpu"
|
||||||
|
# Install aiohttp separately to avoid build errors.
|
||||||
|
RUN pip install aiohttp
|
||||||
|
# Install the TPU and Pallas dependencies.
|
||||||
|
RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
|
||||||
|
RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||||
|
|
||||||
|
# Build vLLM.
|
||||||
|
RUN cd /workspace/vllm && python setup.py develop
|
||||||
|
|
||||||
|
CMD ["/bin/bash"]
|
||||||
14
README.md
14
README.md
@@ -16,16 +16,17 @@ Easy, fast, and cheap LLM serving for everyone
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
**The Fourth vLLM Bay Area Meetup (June 11th 5:30pm-8pm PT)**
|
**Ray Summit CPF is Open (June 4th to June 20th)!**
|
||||||
|
|
||||||
We are thrilled to announce our fourth vLLM Meetup!
|
There will be a track for vLLM at the Ray Summit (09/30-10/02, SF) this year!
|
||||||
The vLLM team will share recent updates and roadmap.
|
If you have cool projects related to vLLM or LLM inference, we would love to see your proposals.
|
||||||
We will also have vLLM collaborators from BentoML and Cloudflare coming up to the stage to discuss their experience in deploying LLMs with vLLM.
|
This will be a great chance for everyone in the community to get together and learn.
|
||||||
Please register [here](https://lu.ma/agivllm) and join us!
|
Please submit your proposal [here](https://raysummit.anyscale.com/flow/anyscale/raysummit2024/landing/page/eventsite)
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
*Latest News* 🔥
|
*Latest News* 🔥
|
||||||
|
- [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing).
|
||||||
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
|
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
|
||||||
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
|
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
|
||||||
- [2024/01] Added ROCm 6.0 support to vLLM.
|
- [2024/01] Added ROCm 6.0 support to vLLM.
|
||||||
@@ -58,7 +59,7 @@ vLLM is flexible and easy to use with:
|
|||||||
- Tensor parallelism support for distributed inference
|
- Tensor parallelism support for distributed inference
|
||||||
- Streaming outputs
|
- Streaming outputs
|
||||||
- OpenAI-compatible API server
|
- OpenAI-compatible API server
|
||||||
- Support NVIDIA GPUs and AMD GPUs
|
- Support NVIDIA GPUs, AMD GPUs, and Intel CPUs
|
||||||
- (Experimental) Prefix caching support
|
- (Experimental) Prefix caching support
|
||||||
- (Experimental) Multi-lora support
|
- (Experimental) Multi-lora support
|
||||||
|
|
||||||
@@ -107,6 +108,7 @@ vLLM is a community project. Our compute resources for development and testing a
|
|||||||
- Replicate
|
- Replicate
|
||||||
- Roblox
|
- Roblox
|
||||||
- RunPod
|
- RunPod
|
||||||
|
- Sequoia Capital
|
||||||
- Trainy
|
- Trainy
|
||||||
- UC Berkeley
|
- UC Berkeley
|
||||||
- UC San Diego
|
- UC San Diego
|
||||||
|
|||||||
@@ -68,9 +68,13 @@ async def async_request_tgi(
|
|||||||
chunk_bytes = chunk_bytes.strip()
|
chunk_bytes = chunk_bytes.strip()
|
||||||
if not chunk_bytes:
|
if not chunk_bytes:
|
||||||
continue
|
continue
|
||||||
|
chunk_bytes = chunk_bytes.decode("utf-8")
|
||||||
|
|
||||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
|
#NOTE: Sometimes TGI returns a ping response without
|
||||||
"data:")
|
# any data, we should skip it.
|
||||||
|
if chunk_bytes.startswith(":"):
|
||||||
|
continue
|
||||||
|
chunk = remove_prefix(chunk_bytes, "data:")
|
||||||
|
|
||||||
data = json.loads(chunk)
|
data = json.loads(chunk)
|
||||||
timestamp = time.perf_counter()
|
timestamp = time.perf_counter()
|
||||||
|
|||||||
@@ -36,7 +36,8 @@ def main(args: argparse.Namespace):
|
|||||||
enable_chunked_prefill=args.enable_chunked_prefill,
|
enable_chunked_prefill=args.enable_chunked_prefill,
|
||||||
download_dir=args.download_dir,
|
download_dir=args.download_dir,
|
||||||
block_size=args.block_size,
|
block_size=args.block_size,
|
||||||
gpu_memory_utilization=args.gpu_memory_utilization)
|
gpu_memory_utilization=args.gpu_memory_utilization,
|
||||||
|
distributed_executor_backend=args.distributed_executor_backend)
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
n=args.n,
|
n=args.n,
|
||||||
@@ -188,7 +189,7 @@ if __name__ == '__main__':
|
|||||||
"--device",
|
"--device",
|
||||||
type=str,
|
type=str,
|
||||||
default="cuda",
|
default="cuda",
|
||||||
choices=["cuda", "cpu"],
|
choices=["cuda", "cpu", "tpu"],
|
||||||
help='device type for vLLM execution, supporting CUDA and CPU.')
|
help='device type for vLLM execution, supporting CUDA and CPU.')
|
||||||
parser.add_argument('--block-size',
|
parser.add_argument('--block-size',
|
||||||
type=int,
|
type=int,
|
||||||
@@ -221,5 +222,12 @@ if __name__ == '__main__':
|
|||||||
help='the fraction of GPU memory to be used for '
|
help='the fraction of GPU memory to be used for '
|
||||||
'the model executor, which can range from 0 to 1.'
|
'the model executor, which can range from 0 to 1.'
|
||||||
'If unspecified, will use the default value of 0.9.')
|
'If unspecified, will use the default value of 0.9.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--distributed-executor-backend',
|
||||||
|
choices=['ray', 'mp'],
|
||||||
|
default=None,
|
||||||
|
help='Backend to use for distributed serving. When more than 1 GPU '
|
||||||
|
'is used, will be automatically set to "ray" if installed '
|
||||||
|
'or "mp" (multiprocessing) otherwise.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
@@ -56,6 +56,9 @@ class BenchmarkMetrics:
|
|||||||
mean_tpot_ms: float
|
mean_tpot_ms: float
|
||||||
median_tpot_ms: float
|
median_tpot_ms: float
|
||||||
p99_tpot_ms: float
|
p99_tpot_ms: float
|
||||||
|
mean_itl_ms: float
|
||||||
|
median_itl_ms: float
|
||||||
|
p99_itl_ms: float
|
||||||
|
|
||||||
|
|
||||||
def sample_sharegpt_requests(
|
def sample_sharegpt_requests(
|
||||||
@@ -200,16 +203,24 @@ def calculate_metrics(
|
|||||||
actual_output_lens = []
|
actual_output_lens = []
|
||||||
total_input = 0
|
total_input = 0
|
||||||
completed = 0
|
completed = 0
|
||||||
|
itls = []
|
||||||
tpots = []
|
tpots = []
|
||||||
ttfts = []
|
ttfts = []
|
||||||
for i in range(len(outputs)):
|
for i in range(len(outputs)):
|
||||||
if outputs[i].success:
|
if outputs[i].success:
|
||||||
output_len = len(tokenizer(outputs[i].generated_text).input_ids)
|
# We use the tokenizer to count the number of output tokens for all
|
||||||
|
# serving backends instead of looking at len(outputs[i].itl) since
|
||||||
|
# multiple output tokens may be bundled together
|
||||||
|
# Note: this may inflate the output token count slightly
|
||||||
|
output_len = len(
|
||||||
|
tokenizer(outputs[i].generated_text,
|
||||||
|
add_special_tokens=False).input_ids)
|
||||||
actual_output_lens.append(output_len)
|
actual_output_lens.append(output_len)
|
||||||
total_input += input_requests[i][1]
|
total_input += input_requests[i][1]
|
||||||
if output_len > 1:
|
if output_len > 1:
|
||||||
tpots.append(
|
tpots.append(
|
||||||
(outputs[i].latency - outputs[i].ttft) / (output_len - 1))
|
(outputs[i].latency - outputs[i].ttft) / (output_len - 1))
|
||||||
|
itls += outputs[i].itl
|
||||||
ttfts.append(outputs[i].ttft)
|
ttfts.append(outputs[i].ttft)
|
||||||
completed += 1
|
completed += 1
|
||||||
else:
|
else:
|
||||||
@@ -234,6 +245,9 @@ def calculate_metrics(
|
|||||||
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
||||||
median_tpot_ms=np.median(tpots or 0) * 1000,
|
median_tpot_ms=np.median(tpots or 0) * 1000,
|
||||||
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
|
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
|
||||||
|
mean_itl_ms=np.mean(itls or 0) * 1000,
|
||||||
|
median_itl_ms=np.median(itls or 0) * 1000,
|
||||||
|
p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
return metrics, actual_output_lens
|
return metrics, actual_output_lens
|
||||||
@@ -333,6 +347,10 @@ async def benchmark(
|
|||||||
print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
|
print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
|
||||||
metrics.median_tpot_ms))
|
metrics.median_tpot_ms))
|
||||||
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
||||||
|
print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-'))
|
||||||
|
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
||||||
|
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
||||||
|
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
@@ -349,6 +367,9 @@ async def benchmark(
|
|||||||
"mean_tpot_ms": metrics.mean_tpot_ms,
|
"mean_tpot_ms": metrics.mean_tpot_ms,
|
||||||
"median_tpot_ms": metrics.median_tpot_ms,
|
"median_tpot_ms": metrics.median_tpot_ms,
|
||||||
"p99_tpot_ms": metrics.p99_tpot_ms,
|
"p99_tpot_ms": metrics.p99_tpot_ms,
|
||||||
|
"mean_itl_ms": metrics.mean_itl_ms,
|
||||||
|
"median_itl_ms": metrics.median_itl_ms,
|
||||||
|
"p99_itl_ms": metrics.p99_itl_ms,
|
||||||
"input_lens": [output.prompt_len for output in outputs],
|
"input_lens": [output.prompt_len for output in outputs],
|
||||||
"output_lens": actual_output_lens,
|
"output_lens": actual_output_lens,
|
||||||
"ttfts": [output.ttft for output in outputs],
|
"ttfts": [output.ttft for output in outputs],
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ def run_vllm(
|
|||||||
enable_prefix_caching: bool,
|
enable_prefix_caching: bool,
|
||||||
enable_chunked_prefill: bool,
|
enable_chunked_prefill: bool,
|
||||||
max_num_batched_tokens: int,
|
max_num_batched_tokens: int,
|
||||||
|
distributed_executor_backend: Optional[str],
|
||||||
gpu_memory_utilization: float = 0.9,
|
gpu_memory_utilization: float = 0.9,
|
||||||
download_dir: Optional[str] = None,
|
download_dir: Optional[str] = None,
|
||||||
) -> float:
|
) -> float:
|
||||||
@@ -100,6 +101,7 @@ def run_vllm(
|
|||||||
download_dir=download_dir,
|
download_dir=download_dir,
|
||||||
enable_chunked_prefill=enable_chunked_prefill,
|
enable_chunked_prefill=enable_chunked_prefill,
|
||||||
max_num_batched_tokens=max_num_batched_tokens,
|
max_num_batched_tokens=max_num_batched_tokens,
|
||||||
|
distributed_executor_backend=distributed_executor_backend,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add the requests to the engine.
|
# Add the requests to the engine.
|
||||||
@@ -225,8 +227,8 @@ def main(args: argparse.Namespace):
|
|||||||
args.enforce_eager, args.kv_cache_dtype,
|
args.enforce_eager, args.kv_cache_dtype,
|
||||||
args.quantization_param_path, args.device,
|
args.quantization_param_path, args.device,
|
||||||
args.enable_prefix_caching, args.enable_chunked_prefill,
|
args.enable_prefix_caching, args.enable_chunked_prefill,
|
||||||
args.max_num_batched_tokens, args.gpu_memory_utilization,
|
args.max_num_batched_tokens, args.distributed_executor_backend,
|
||||||
args.download_dir)
|
args.gpu_memory_utilization, args.download_dir)
|
||||||
elif args.backend == "hf":
|
elif args.backend == "hf":
|
||||||
assert args.tensor_parallel_size == 1
|
assert args.tensor_parallel_size == 1
|
||||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||||
@@ -344,7 +346,7 @@ if __name__ == "__main__":
|
|||||||
"--device",
|
"--device",
|
||||||
type=str,
|
type=str,
|
||||||
default="cuda",
|
default="cuda",
|
||||||
choices=["cuda", "cpu"],
|
choices=["cuda", "cpu", "tpu"],
|
||||||
help='device type for vLLM execution, supporting CUDA and CPU.')
|
help='device type for vLLM execution, supporting CUDA and CPU.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-prefix-caching",
|
"--enable-prefix-caching",
|
||||||
@@ -368,6 +370,13 @@ if __name__ == "__main__":
|
|||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help='Path to save the throughput results in JSON format.')
|
help='Path to save the throughput results in JSON format.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--distributed-executor-backend',
|
||||||
|
choices=['ray', 'mp'],
|
||||||
|
default=None,
|
||||||
|
help='Backend to use for distributed serving. When more than 1 GPU '
|
||||||
|
'is used, will be automatically set to "ray" if installed '
|
||||||
|
'or "mp" (multiprocessing) otherwise.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args.tokenizer is None:
|
if args.tokenizer is None:
|
||||||
args.tokenizer = args.model
|
args.tokenizer = args.model
|
||||||
|
|||||||
348
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
Normal file
348
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
Normal file
@@ -0,0 +1,348 @@
|
|||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
|
import pickle as pkl
|
||||||
|
import time
|
||||||
|
from typing import Callable, Iterable, List, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.benchmark as TBenchmark
|
||||||
|
from torch.utils.benchmark import Measurement as TMeasurement
|
||||||
|
from weight_shapes import WEIGHT_SHAPES
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
|
||||||
|
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:]
|
||||||
|
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||||
|
DEFAULT_TP_SIZES = [1]
|
||||||
|
|
||||||
|
# helpers
|
||||||
|
|
||||||
|
|
||||||
|
def to_fp8(tensor: torch.tensor) -> torch.tensor:
|
||||||
|
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||||
|
return torch.round(tensor.clamp(
|
||||||
|
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
|
||||||
|
def to_int8(tensor: torch.tensor) -> torch.tensor:
|
||||||
|
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||||
|
|
||||||
|
|
||||||
|
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
||||||
|
k: int) -> Tuple[torch.tensor, torch.tensor]:
|
||||||
|
|
||||||
|
a = torch.randn((m, k), device='cuda') * 5
|
||||||
|
b = torch.randn((n, k), device='cuda').t() * 5
|
||||||
|
|
||||||
|
if dtype == torch.int8:
|
||||||
|
return to_int8(a), to_int8(b)
|
||||||
|
if dtype == torch.float8_e4m3fn:
|
||||||
|
return to_fp8(a), to_fp8(b)
|
||||||
|
|
||||||
|
raise ValueError("unsupported dtype")
|
||||||
|
|
||||||
|
|
||||||
|
# impl
|
||||||
|
|
||||||
|
|
||||||
|
def pytorch_i8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
|
||||||
|
scale_b: torch.tensor,
|
||||||
|
out_dtype: torch.dtype) -> torch.tensor:
|
||||||
|
return torch.mm(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
|
||||||
|
scale_b: torch.tensor,
|
||||||
|
out_dtype: torch.dtype) -> torch.tensor:
|
||||||
|
return torch._scaled_mm(a,
|
||||||
|
b,
|
||||||
|
scale_a=scale_a,
|
||||||
|
scale_b=scale_b,
|
||||||
|
out_dtype=out_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
|
||||||
|
scale_a: torch.tensor, scale_b: torch.tensor,
|
||||||
|
out_dtype: torch.dtype) -> torch.tensor:
|
||||||
|
return torch._scaled_mm(a,
|
||||||
|
b,
|
||||||
|
scale_a=scale_a,
|
||||||
|
scale_b=scale_b,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
use_fast_accum=True)
|
||||||
|
|
||||||
|
|
||||||
|
def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
|
||||||
|
scale_b: torch.tensor,
|
||||||
|
out_dtype: torch.dtype) -> torch.tensor:
|
||||||
|
return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
# bench
|
||||||
|
def bench_fn(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
|
||||||
|
scale_b: torch.tensor, out_dtype: torch.dtype, label: str,
|
||||||
|
sub_label: str, fn: Callable, description: str) -> TMeasurement:
|
||||||
|
|
||||||
|
min_run_time = 1
|
||||||
|
|
||||||
|
globals = {
|
||||||
|
"a": a,
|
||||||
|
"b": b,
|
||||||
|
"scale_a": scale_a,
|
||||||
|
"scale_b": scale_b,
|
||||||
|
"out_dtype": out_dtype,
|
||||||
|
"fn": fn,
|
||||||
|
}
|
||||||
|
return TBenchmark.Timer(
|
||||||
|
stmt="fn(a, b, scale_a, scale_b, out_dtype)",
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description=description,
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
|
||||||
|
|
||||||
|
def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||||
|
sub_label: str) -> Iterable[TMeasurement]:
|
||||||
|
assert dtype == torch.int8
|
||||||
|
a, b = make_rand_tensors(torch.int8, m, n, k)
|
||||||
|
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
|
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
|
|
||||||
|
timers = []
|
||||||
|
# pytorch impl
|
||||||
|
timers.append(
|
||||||
|
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
|
||||||
|
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
|
||||||
|
torch.bfloat16, label, sub_label, pytorch_i8_impl,
|
||||||
|
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
|
||||||
|
|
||||||
|
# cutlass impl
|
||||||
|
timers.append(
|
||||||
|
bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
|
||||||
|
torch.bfloat16, label, sub_label, cutlass_impl,
|
||||||
|
"cutlass_i8_i8_bf16_scaled_mm"))
|
||||||
|
|
||||||
|
return timers
|
||||||
|
|
||||||
|
|
||||||
|
def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||||
|
sub_label: str) -> Iterable[TMeasurement]:
|
||||||
|
assert dtype == torch.float8_e4m3fn
|
||||||
|
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
|
||||||
|
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
|
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
|
|
||||||
|
timers = []
|
||||||
|
|
||||||
|
# pytorch impl: bf16 output, without fp8 fast accum
|
||||||
|
timers.append(
|
||||||
|
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
||||||
|
pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm"))
|
||||||
|
|
||||||
|
# pytorch impl: bf16 output, with fp8 fast accum
|
||||||
|
timers.append(
|
||||||
|
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
||||||
|
pytorch_fp8_impl_fast_accum,
|
||||||
|
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"))
|
||||||
|
|
||||||
|
# pytorch impl: fp16 output, without fp8 fast accum
|
||||||
|
timers.append(
|
||||||
|
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
|
||||||
|
pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm"))
|
||||||
|
|
||||||
|
# pytorch impl: fp16 output, with fp8 fast accum
|
||||||
|
timers.append(
|
||||||
|
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
|
||||||
|
pytorch_fp8_impl_fast_accum,
|
||||||
|
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"))
|
||||||
|
|
||||||
|
# cutlass impl: bf16 output
|
||||||
|
timers.append(
|
||||||
|
bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
|
||||||
|
torch.bfloat16, label, sub_label, cutlass_impl,
|
||||||
|
"cutlass_fp8_fp8_bf16_scaled_mm"))
|
||||||
|
# cutlass impl: fp16 output
|
||||||
|
timers.append(
|
||||||
|
bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
|
||||||
|
torch.float16, label, sub_label, cutlass_impl,
|
||||||
|
"cutlass_fp8_fp8_fp16_scaled_mm"))
|
||||||
|
return timers
|
||||||
|
|
||||||
|
|
||||||
|
def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||||
|
sub_label: str) -> Iterable[TMeasurement]:
|
||||||
|
if dtype == torch.int8:
|
||||||
|
return bench_int8(dtype, m, k, n, label, sub_label)
|
||||||
|
if dtype == torch.float8_e4m3fn:
|
||||||
|
return bench_fp8(dtype, m, k, n, label, sub_label)
|
||||||
|
raise ValueError("unsupported type")
|
||||||
|
|
||||||
|
|
||||||
|
# runner
|
||||||
|
def print_timers(timers: Iterable[TMeasurement]):
|
||||||
|
compare = TBenchmark.Compare(timers)
|
||||||
|
compare.print()
|
||||||
|
|
||||||
|
|
||||||
|
def run(dtype: torch.dtype,
|
||||||
|
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for m, k, n in MKNs:
|
||||||
|
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
|
||||||
|
f"MKN=({m}x{k}x{n})")
|
||||||
|
print_timers(timers)
|
||||||
|
results.extend(timers)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
# output makers
|
||||||
|
def make_output(data: Iterable[TMeasurement],
|
||||||
|
MKNs: Iterable[Tuple[int, int, int]],
|
||||||
|
base_description: str,
|
||||||
|
timestamp=None):
|
||||||
|
|
||||||
|
print(f"== All Results {base_description} ====")
|
||||||
|
print_timers(data)
|
||||||
|
|
||||||
|
# pickle all the results
|
||||||
|
timestamp = int(time.time()) if timestamp is None else timestamp
|
||||||
|
with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
|
||||||
|
pkl.dump(data, f)
|
||||||
|
|
||||||
|
|
||||||
|
# argparse runners
|
||||||
|
|
||||||
|
|
||||||
|
def run_square_bench(args):
|
||||||
|
dim_sizes = list(
|
||||||
|
range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||||
|
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
||||||
|
data = run(args.dtype, MKNs)
|
||||||
|
|
||||||
|
make_output(data, MKNs, f"square_bench-{args.dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_range_bench(args):
|
||||||
|
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
|
||||||
|
n = len(dim_sizes)
|
||||||
|
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
|
||||||
|
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
|
||||||
|
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
|
||||||
|
MKNs = list(zip(Ms, Ks, Ns))
|
||||||
|
data = run(args.dtype, MKNs)
|
||||||
|
|
||||||
|
make_output(data, MKNs, f"range_bench-{args.dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_model_bench(args):
|
||||||
|
|
||||||
|
print("Benchmarking models:")
|
||||||
|
for i, model in enumerate(args.models):
|
||||||
|
print(f"[{i}] {model}")
|
||||||
|
|
||||||
|
def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
|
||||||
|
KNs = []
|
||||||
|
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
|
||||||
|
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
||||||
|
KNs.append(KN)
|
||||||
|
return KNs
|
||||||
|
|
||||||
|
model_bench_data = []
|
||||||
|
models_tps = list(itertools.product(args.models, args.tp_sizes))
|
||||||
|
for model, tp_size in models_tps:
|
||||||
|
Ms = args.batch_sizes
|
||||||
|
KNs = model_shapes(model, tp_size)
|
||||||
|
MKNs = []
|
||||||
|
for m in Ms:
|
||||||
|
for k, n in KNs:
|
||||||
|
MKNs.append((m, k, n))
|
||||||
|
|
||||||
|
data = run(args.dtype, MKNs)
|
||||||
|
model_bench_data.append(data)
|
||||||
|
|
||||||
|
# Print all results
|
||||||
|
for data, model_tp in zip(model_bench_data, models_tps):
|
||||||
|
model, tp_size = model_tp
|
||||||
|
print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
|
||||||
|
print_timers(data)
|
||||||
|
|
||||||
|
timestamp = int(time.time())
|
||||||
|
|
||||||
|
all_data = []
|
||||||
|
for d in model_bench_data:
|
||||||
|
all_data.extend(d)
|
||||||
|
# pickle all data
|
||||||
|
with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
|
||||||
|
pkl.dump(all_data, f)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
def to_torch_dtype(dt):
|
||||||
|
if dt == "int8":
|
||||||
|
return torch.int8
|
||||||
|
if dt == "fp8":
|
||||||
|
return torch.float8_e4m3fn
|
||||||
|
raise ValueError("unsupported dtype")
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="""
|
||||||
|
Benchmark Cutlass GEMM.
|
||||||
|
|
||||||
|
To run square GEMMs:
|
||||||
|
python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
|
||||||
|
|
||||||
|
To run constant N and K and sweep M:
|
||||||
|
python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
|
||||||
|
|
||||||
|
To run dimensions from a model:
|
||||||
|
python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
|
||||||
|
|
||||||
|
Output:
|
||||||
|
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
|
||||||
|
""", # noqa: E501
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter)
|
||||||
|
|
||||||
|
parser.add_argument("--dtype",
|
||||||
|
type=to_torch_dtype,
|
||||||
|
required=True,
|
||||||
|
help="Available options are ['int8', 'fp8']")
|
||||||
|
subparsers = parser.add_subparsers(dest="cmd")
|
||||||
|
|
||||||
|
square_parser = subparsers.add_parser("square_bench")
|
||||||
|
square_parser.add_argument("--dim-start", type=int, required=True)
|
||||||
|
square_parser.add_argument("--dim-end", type=int, required=True)
|
||||||
|
square_parser.add_argument("--dim-increment", type=int, required=True)
|
||||||
|
square_parser.set_defaults(func=run_square_bench)
|
||||||
|
|
||||||
|
range_parser = subparsers.add_parser("range_bench")
|
||||||
|
range_parser.add_argument("--dim-start", type=int, required=True)
|
||||||
|
range_parser.add_argument("--dim-end", type=int, required=True)
|
||||||
|
range_parser.add_argument("--dim-increment", type=int, required=True)
|
||||||
|
range_parser.add_argument("--m-constant", type=int, default=None)
|
||||||
|
range_parser.add_argument("--n-constant", type=int, default=None)
|
||||||
|
range_parser.add_argument("--k-constant", type=int, default=None)
|
||||||
|
range_parser.set_defaults(func=run_range_bench)
|
||||||
|
|
||||||
|
model_parser = subparsers.add_parser("model_bench")
|
||||||
|
model_parser.add_argument("--models",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_MODELS,
|
||||||
|
choices=WEIGHT_SHAPES.keys())
|
||||||
|
model_parser.add_argument("--tp-sizes",
|
||||||
|
nargs="+",
|
||||||
|
type=int,
|
||||||
|
default=DEFAULT_TP_SIZES)
|
||||||
|
model_parser.add_argument("--batch-sizes",
|
||||||
|
nargs="+",
|
||||||
|
type=int,
|
||||||
|
default=DEFAULT_BATCH_SIZES)
|
||||||
|
model_parser.set_defaults(func=run_model_bench)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.func(args)
|
||||||
37
benchmarks/cutlass_benchmarks/weight_shapes.py
Normal file
37
benchmarks/cutlass_benchmarks/weight_shapes.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
# Weight Shapes are in the format
|
||||||
|
# ([K, N], TP_SPLIT_DIM)
|
||||||
|
# Example:
|
||||||
|
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
|
||||||
|
# - TP1 : K = 14336, N = 4096
|
||||||
|
# - TP2 : K = 7168, N = 4096
|
||||||
|
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
|
||||||
|
# - TP1 : K = 4096, N = 6144
|
||||||
|
# - TP4 : K = 4096, N = 1536
|
||||||
|
|
||||||
|
# TP1 shapes
|
||||||
|
# GEMM weight shapes per model at tensor-parallel size 1.
# Each entry is ([K, N], TP_SPLIT_DIM), where TP_SPLIT_DIM is the index of the
# dimension (0 -> K, 1 -> N) that is divided by the tensor-parallel size.
WEIGHT_SHAPES = {
    "mistralai/Mistral-7B-v0.1": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-2-7b-hf": [
        ([4096, 12288], 1),
        ([4096, 4096], 0),
        ([4096, 22016], 1),
        ([11008, 4096], 0),
    ],
    "meta-llama/Llama-2-13b-hf": [
        ([5120, 15360], 1),
        ([5120, 5120], 0),
        ([5120, 27648], 1),
        ([13824, 5120], 0),
    ],
    "meta-llama/Llama-2-70b-hf": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
}
|
||||||
@@ -1,239 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import triton
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.fused_moe import (fused_moe,
|
|
||||||
get_config_file_name)
|
|
||||||
|
|
||||||
|
|
||||||
def main(model, tp_size, gpu, dtype: str):
    """Tune the fused MoE kernel for `model` over a fixed sweep of batch sizes.

    Args:
        model: Mixtral variant name understood by `run_grid`.
        tp_size: Tensor-parallel size forwarded to `run_grid`.
        gpu: GPU index; exported through CUDA_VISIBLE_DEVICES.
        dtype: Data type string forwarded to `run_grid`.
    """
    # Expose only the requested GPU to CUDA.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

    batch_sizes = (1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024,
                   1536, 2048, 3072, 4096)
    for batch_size in batch_sizes:
        run_grid(batch_size,
                 model=model,
                 method=fused_moe,
                 gpu=gpu,
                 tp_size=tp_size,
                 dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def run_grid(bs, model, method, gpu, tp_size, dtype: str):
    """Grid-search kernel configs for one batch size and save the fastest.

    Every combination of block sizes, group size, warp count and stage count
    is timed with `run_timing`. The fastest config is stored under the key
    `str(bs)` in the JSON file named by `get_config_file_name`, merged with
    whatever that file already contains.

    Args:
        bs: Batch size (number of rows of the hidden-state input).
        model: '8x7B' or '8x22B'; selects the model dimensions.
        method: Callable to benchmark (forwarded to `run_timing`).
        gpu: Not used inside this function.
        tp_size: Tensor-parallel size; divides the intermediate size.
        dtype: "float8" selects the fp8 path and config file name.

    Raises:
        ValueError: If `model` is not a supported Mixtral variant.
    """
    if model == '8x7B':
        d_model = 4096
        model_intermediate_size = 14336
        num_layers = 32
    elif model == '8x22B':
        d_model = 6144
        model_intermediate_size = 16384
        num_layers = 56
    else:
        raise ValueError(f'Unsupported Mixtral model {model}')
    num_total_experts = 8
    top_k = 2
    # tp_size = 2
    num_calls = 100

    num_warmup_trials = 1
    num_trials = 1

    # Same iteration order as the original nested loops.
    configs = [{
        "BLOCK_SIZE_M": block_size_m,
        "BLOCK_SIZE_N": block_size_n,
        "BLOCK_SIZE_K": block_size_k,
        "GROUP_SIZE_M": group_size_m,
        "num_warps": num_warps,
        "num_stages": num_stages,
    } for block_size_n in [32, 64, 128, 256]
               for block_size_m in [16, 32, 64, 128, 256]
               for block_size_k in [64, 128, 256]
               for group_size_m in [1, 16, 32, 64]
               for num_warps in [4, 8]
               for num_stages in [2, 3, 4, 5]]

    best_config = None
    best_time_us = 1e20

    print(f'{tp_size=} {bs=}')

    # Arguments shared by the warmup and the measured runs.
    timing_kwargs = dict(
        num_calls=num_calls,
        bs=bs,
        d_model=d_model,
        num_total_experts=num_total_experts,
        top_k=top_k,
        tp_size=tp_size,
        model_intermediate_size=model_intermediate_size,
        method=method,
        dtype=dtype,
    )

    for config in tqdm(configs):
        # warmup
        try:
            for _ in range(num_warmup_trials):
                run_timing(config=config, **timing_kwargs)
        except triton.runtime.autotuner.OutOfResources:
            # This config does not fit on the device; skip it.
            continue

        # trial
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(config=config, **timing_kwargs)

            kernel_dur_us = 1000 * kernel_dur_ms
            model_dur_ms = kernel_dur_ms * num_layers

            if kernel_dur_us < best_time_us:
                best_config = config
                best_time_us = kernel_dur_us

            tqdm.write(
                f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
                f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
                f'{d_model=} {model_intermediate_size=} {num_layers=}')

    print("best_time_us", best_time_us)
    print("best_config", best_config)

    if best_config is None:
        # Every candidate ran out of resources. Do not store a null entry
        # in the tuned-config file.
        print(f"no runnable config found for {bs=}; not writing config")
        return

    # holds Dict[str, Dict[str, int]]
    filename = get_config_file_name(num_total_experts,
                                    model_intermediate_size // tp_size,
                                    "float8" if dtype == "float8" else None)
    print(f"writing config to file {filename}")
    existing_content = {}
    if os.path.exists(filename):
        with open(filename, "r") as f:
            existing_content = json.load(f)
    existing_content[str(bs)] = best_config
    with open(filename, "w") as f:
        json.dump(existing_content, f, indent=4)
        f.write("\n")
|
|
||||||
|
|
||||||
|
|
||||||
def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
               top_k: int, tp_size: int, model_intermediate_size: int, method,
               config, dtype: str) -> float:
    """Time `num_calls` back-to-back invocations of `method` with CUDA events.

    Args:
        num_calls: Number of kernel invocations to time.
        bs: Number of rows in the hidden-state input.
        d_model: Hidden size.
        num_total_experts: Number of experts.
        top_k: Number of experts selected per token.
        tp_size: Tensor-parallel size; divides `model_intermediate_size`.
        model_intermediate_size: Unsharded intermediate size.
        method: Callable to benchmark; called with the fused_moe keyword
            arguments used below.
        config: Kernel config passed as `override_config`.
        dtype: "float8" converts the weights to float8_e4m3fn and passes
            unit scales; any other value keeps float16 weights.

    Returns:
        Elapsed CUDA-event time divided by `num_calls` (named `dur_ms`).
    """
    shard_intermediate_size = model_intermediate_size // tp_size

    hidden_states = torch.rand(
        (bs, d_model),
        device="cuda:0",
        dtype=torch.float16,
    )

    w1 = torch.rand(
        (num_total_experts, 2 * shard_intermediate_size, d_model),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w2 = torch.rand(
        (num_total_experts, d_model, shard_intermediate_size),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None

    if dtype == "float8":
        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)
        w1_scale = torch.ones(num_total_experts,
                              device=hidden_states.device,
                              dtype=torch.float32)
        w2_scale = torch.ones(num_total_experts,
                              device=hidden_states.device,
                              dtype=torch.float32)
        a1_scale = torch.ones(1,
                              device=hidden_states.device,
                              dtype=torch.float32)
        a2_scale = torch.ones(1,
                              device=hidden_states.device,
                              dtype=torch.float32)

    # One independent gating distribution per timed call.
    gating_output = F.softmax(torch.rand(
        (num_calls, bs, num_total_experts),
        device=hidden_states.device,
        dtype=torch.float32,
    ),
                              dim=-1)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for i in range(num_calls):
        hidden_states = method(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            gating_output=gating_output[i],
            # Use the caller's value; this was previously hard-coded to 2.
            topk=top_k,
            renormalize=True,
            inplace=True,
            override_config=config,
            use_fp8=dtype == "float8",
        )
    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the tuning sweep.
    parser = argparse.ArgumentParser(
        prog='benchmark_mixtral_moe',
        description='Benchmark and tune the fused_moe kernel',
    )
    parser.add_argument(
        '--dtype',
        type=str,
        # The previous default 'auto' was not one of the choices. Only
        # "float8" is special-cased downstream, so 'float16' selects the
        # same code path the old default did.
        default='float16',
        choices=['float8', 'float16'],
        help='Data type used for fused_moe kernel computations',
    )
    parser.add_argument('--model',
                        type=str,
                        default='8x7B',
                        choices=['8x7B', '8x22B'],
                        help='The Mixtral model to benchmark')
    parser.add_argument('--tp-size',
                        type=int,
                        default=2,
                        help='Tensor parallel size')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help="GPU ID for benchmarking")
    args = parser.parse_args()
    sys.exit(main(args.model, args.tp_size, args.gpu, args.dtype))
|
|
||||||
322
benchmarks/kernels/benchmark_moe.py
Normal file
322
benchmarks/kernels/benchmark_moe.py
Normal file
@@ -0,0 +1,322 @@
|
|||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
import ray
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
from ray.experimental.tqdm_ray import tqdm
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_config(
    config: Dict[str, int],
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8: bool,
    num_iters: int = 100,
) -> float:
    """Measure the average latency of one `fused_moe` call for `config`.

    Ten calls are captured in a CUDA graph, and the graph is replayed
    `num_iters` times with fresh gating values before each replay.

    Args:
        config: Kernel config passed to `fused_moe` as `override_config`.
        num_tokens: Number of rows of the input activations.
        num_experts: Number of experts.
        shard_intermediate_size: First dimension of each expert's w1; w2
            uses half of it as its last dimension.
        hidden_size: Hidden size.
        topk: Number of experts selected per token.
        dtype: Activation dtype, also the weight dtype unless `use_fp8`.
        use_fp8: Convert weights to float8_e4m3fn and pass scale tensors.
        num_iters: Number of timed graph replays.

    Returns:
        Average time per `fused_moe` call in microseconds.
    """
    # Number of fused_moe calls captured in the CUDA graph.
    calls_per_replay = 10

    # Tensors are created without an explicit device.
    # NOTE(review): presumably relies on the default device having been set
    # to CUDA by the caller; confirm.
    weight_dtype = torch.float16 if use_fp8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    w1 = torch.randn(num_experts,
                     shard_intermediate_size,
                     hidden_size,
                     dtype=weight_dtype)
    w2 = torch.randn(num_experts,
                     hidden_size,
                     shard_intermediate_size // 2,
                     dtype=weight_dtype)
    gating_output = torch.randn(num_iters,
                                num_tokens,
                                num_experts,
                                dtype=torch.float32)

    w1_scale = w2_scale = a1_scale = a2_scale = None
    if use_fp8:
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)

    # Static buffer read by the captured graph; refreshed before each replay.
    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

    def run():
        fused_moe(
            x,
            w1,
            w2,
            input_gating,
            topk,
            renormalize=True,
            inplace=True,
            override_config=config,
            use_fp8=use_fp8,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture the invocations with a CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(calls_per_replay):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    total_ms = 0
    for step in range(num_iters):
        # Load this iteration's gating values into the static buffer.
        input_gating.copy_(gating_output[step])
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        total_ms += start_event.elapsed_time(end_event)

    avg = total_ms / (num_iters * calls_per_replay) * 1000  # us
    graph.reset()
    return avg
|
||||||
|
|
||||||
|
|
||||||
|
def get_configs_compute_bound() -> List[Dict[str, int]]:
    """Return the search space of kernel configs used for tuning.

    The list is the full cross product of the candidate values below.
    `num_stages` varies slowest and `GROUP_SIZE_M` varies fastest.
    """
    # Reduced search space for faster tuning.
    # TODO(woosuk): Increase the search space and use a performance model to
    # prune the search space.
    return [{
        "BLOCK_SIZE_M": block_m,
        "BLOCK_SIZE_N": block_n,
        "BLOCK_SIZE_K": block_k,
        "GROUP_SIZE_M": group_size,
        "num_warps": num_warps,
        "num_stages": num_stages,
    } for num_stages in (2, 3, 4, 5)
            for block_m in (16, 32, 64, 128, 256)
            for block_k in (64, 128, 256)
            for block_n in (32, 64, 128, 256)
            for num_warps in (4, 8)
            for group_size in (1, 16, 32, 64)]
|
||||||
|
|
||||||
|
|
||||||
|
@ray.remote(num_gpus=1)
class BenchmarkWorker:
    """Ray actor that benchmarks or tunes the fused MoE kernel.

    The actor is declared with `num_gpus=1`.
    """

    def __init__(self, seed: int) -> None:
        # Tensors created without an explicit device are placed on CUDA.
        torch.set_default_device("cuda")
        torch.cuda.manual_seed_all(seed)
        self.seed = seed

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8: bool,
    ) -> Tuple[Dict[str, int], float]:
        """Benchmark the config that would be picked for this problem size.

        If `get_moe_configs` returns configs, the one whose key is closest
        to `num_tokens` is used; otherwise `get_default_config` is used.

        Returns:
            A `(config, kernel_time)` pair, where `kernel_time` is the value
            returned by `benchmark_config`.
        """
        torch.cuda.manual_seed_all(self.seed)

        dtype_str = "float8" if use_fp8 else None
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
                                    dtype_str)
        if op_config is None:
            config = get_default_config(num_tokens, num_experts,
                                        shard_intermediate_size, hidden_size,
                                        topk, dtype_str)
        else:
            # Pick the tuned entry whose key is nearest to num_tokens.
            config = op_config[min(op_config.keys(),
                                   key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(config, num_tokens, num_experts,
                                       shard_intermediate_size, hidden_size,
                                       topk, dtype, use_fp8)
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8: bool,
        search_space: List[Dict[str, int]],
    ) -> Dict[str, int]:
        """Return the fastest config in `search_space` for this problem size.

        Configs that raise triton's `OutOfResources` are skipped.

        Raises:
            RuntimeError: If no config in `search_space` could be
                benchmarked, including when `search_space` is empty.
        """
        best_config = None
        best_time = float("inf")
        for config in tqdm(search_space):
            try:
                kernel_time = benchmark_config(config,
                                               num_tokens,
                                               num_experts,
                                               shard_intermediate_size,
                                               hidden_size,
                                               topk,
                                               dtype,
                                               use_fp8,
                                               num_iters=10)
            except triton.runtime.autotuner.OutOfResources:
                # Some configurations may be invalid and fail to compile.
                continue

            if kernel_time < best_time:
                best_time = kernel_time
                best_config = config
        if best_config is None:
            # Fail with a clear message instead of returning None, which
            # would contradict the return annotation.
            raise RuntimeError(
                f"No valid kernel config found for batch_size={num_tokens}")
        now = datetime.now()
        print(f"[{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        return best_config
|
||||||
|
|
||||||
|
|
||||||
|
def sort_config(config: Dict[str, int]) -> Dict[str, int]:
    """Return a copy of `config` holding only the known keys, in a fixed order.

    Raises:
        KeyError: If any of the six expected keys is missing from `config`.
    """
    key_order = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_warps",
        "num_stages",
    )
    return {key: config[key] for key in key_order}
|
||||||
|
|
||||||
|
|
||||||
|
def save_configs(
    configs: Dict[int, Dict[str, int]],
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8: bool,
) -> None:
    """Write `configs` as indented JSON to the tuned-config file.

    The file name comes from `get_config_file_name`. Any existing file is
    overwritten. `hidden_size`, `topk` and `dtype` are accepted but not used.
    """
    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    post_activation_size = shard_intermediate_size // 2
    filename = get_config_file_name(num_experts, post_activation_size,
                                    "float8" if use_fp8 else None)
    print(f"Writing best config to {filename}...")
    # NOTE(review): `json` is not imported by name in the visible imports;
    # presumably it arrives through the wildcard import. Confirm.
    with open(filename, "w") as out_file:
        json.dump(configs, out_file, indent=4)
        out_file.write("\n")
|
||||||
|
|
||||||
|
|
||||||
|
def main(args: argparse.Namespace):
    """Benchmark or tune the fused MoE kernel for the model in `args`.

    Model dimensions are read from the Hugging Face config of `args.model`.
    Work is spread round-robin over one `BenchmarkWorker` per GPU that Ray
    reports as available. With `args.tune` the best config per batch size is
    saved via `save_configs`; otherwise the selected config and its kernel
    time are printed for each batch size.

    Raises:
        RuntimeError: If Ray reports no available GPU.
    """
    print(args)

    config = AutoConfig.from_pretrained(args.model)
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
    else:
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
    # Identical for both architectures. benchmark_config uses this as w1's
    # first dimension and half of it as w2's last dimension.
    shard_intermediate_size = 2 * intermediate_size // args.tp_size

    hidden_size = config.hidden_size
    dtype = config.torch_dtype
    use_fp8 = args.dtype == "fp8"

    if args.batch_size is None:
        batch_sizes = [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init()
    # The "GPU" entry can be missing when no GPU is available. Report that
    # clearly instead of failing with a KeyError.
    num_gpus = int(ray.available_resources().get("GPU", 0))
    if num_gpus < 1:
        raise RuntimeError(
            "No GPU is available to Ray; at least one is required.")
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: List[Any]) -> List[Any]:
        # Submit every input round-robin across the workers, then wait for
        # all results. Results come back in the order of `inputs`.
        outputs = []
        for idx, input_args in enumerate(inputs):
            worker = workers[idx % num_gpus]
            worker_method = getattr(worker, method)
            outputs.append(worker_method.remote(*input_args))
        return ray.get(outputs)

    if args.tune:
        search_space = get_configs_compute_bound()
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
                      topk, dtype, use_fp8, search_space)
                     for batch_size in batch_sizes])
        best_configs = {
            M: sort_config(config)
            for M, config in zip(batch_sizes, configs)
        }
        save_configs(best_configs, E, shard_intermediate_size, hidden_size,
                     topk, dtype, use_fp8)
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute("benchmark",
                              [(batch_size, E, shard_intermediate_size,
                                hidden_size, topk, dtype, use_fp8)
                               for batch_size in batch_sizes])

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: build the parser, then hand over to main().
    parser = argparse.ArgumentParser()
    # Model whose Hugging Face config supplies the MoE dimensions.
    parser.add_argument(
        "--model",
        type=str,
        default="mistralai/Mixtral-8x7B-Instruct-v0.1",
    )
    parser.add_argument("--tp-size", "-tp", type=int, default=2)
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["auto", "fp8"],
        default="auto",
    )
    parser.add_argument("--seed", type=int, default=0)
    # When omitted, main() sweeps its built-in list of batch sizes.
    parser.add_argument("--batch-size", type=int, required=False)
    # Tune and save configs instead of only benchmarking.
    parser.add_argument("--tune", action="store_true")
    args = parser.parse_args()

    main(args)
|
||||||
@@ -12,7 +12,7 @@ include_directories("${CMAKE_SOURCE_DIR}/csrc")
|
|||||||
#
|
#
|
||||||
# Check the compile flags
|
# Check the compile flags
|
||||||
#
|
#
|
||||||
list(APPEND CXX_COMPILE_FLAGS
|
list(APPEND CXX_COMPILE_FLAGS
|
||||||
"-fopenmp"
|
"-fopenmp"
|
||||||
"-DVLLM_CPU_EXTENSION")
|
"-DVLLM_CPU_EXTENSION")
|
||||||
|
|
||||||
@@ -44,8 +44,8 @@ if (AVX512_FOUND)
|
|||||||
|
|
||||||
find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
|
find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
|
||||||
if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
|
if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
|
||||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||||
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
|
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
|
||||||
else()
|
else()
|
||||||
message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
|
message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
|
||||||
@@ -73,7 +73,7 @@ set(VLLM_EXT_SRC
|
|||||||
"csrc/cpu/cache.cpp"
|
"csrc/cpu/cache.cpp"
|
||||||
"csrc/cpu/layernorm.cpp"
|
"csrc/cpu/layernorm.cpp"
|
||||||
"csrc/cpu/pos_encoding.cpp"
|
"csrc/cpu/pos_encoding.cpp"
|
||||||
"csrc/cpu/pybind.cpp")
|
"csrc/cpu/torch_bindings.cpp")
|
||||||
|
|
||||||
define_gpu_extension_target(
|
define_gpu_extension_target(
|
||||||
_C
|
_C
|
||||||
@@ -81,10 +81,10 @@ define_gpu_extension_target(
|
|||||||
LANGUAGE CXX
|
LANGUAGE CXX
|
||||||
SOURCES ${VLLM_EXT_SRC}
|
SOURCES ${VLLM_EXT_SRC}
|
||||||
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
|
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
|
||||||
WITH_SOABI
|
USE_SABI 3
|
||||||
|
WITH_SOABI
|
||||||
)
|
)
|
||||||
|
|
||||||
add_custom_target(default)
|
add_custom_target(default)
|
||||||
message(STATUS "Enabling C extension.")
|
message(STATUS "Enabling C extension.")
|
||||||
add_dependencies(default _C)
|
add_dependencies(default _C)
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
|
macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
|
||||||
file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
|
file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
|
||||||
set(Python_EXECUTABLE ${EXECUTABLE})
|
set(Python_EXECUTABLE ${EXECUTABLE})
|
||||||
find_package(Python COMPONENTS Interpreter Development.Module)
|
find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
|
||||||
if (NOT Python_FOUND)
|
if (NOT Python_FOUND)
|
||||||
message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
|
message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
|
||||||
endif()
|
endif()
|
||||||
@@ -294,6 +294,7 @@ endmacro()
|
|||||||
# INCLUDE_DIRECTORIES <dirs> - Extra include directories.
|
# INCLUDE_DIRECTORIES <dirs> - Extra include directories.
|
||||||
# LIBRARIES <libraries> - Extra link libraries.
|
# LIBRARIES <libraries> - Extra link libraries.
|
||||||
# WITH_SOABI - Generate library with python SOABI suffix name.
|
# WITH_SOABI - Generate library with python SOABI suffix name.
|
||||||
|
# USE_SABI <version> - Use python stable api <version>
|
||||||
#
|
#
|
||||||
# Note: optimization level/debug info is set via cmake build type.
|
# Note: optimization level/debug info is set via cmake build type.
|
||||||
#
|
#
|
||||||
@@ -301,7 +302,7 @@ function (define_gpu_extension_target GPU_MOD_NAME)
|
|||||||
cmake_parse_arguments(PARSE_ARGV 1
|
cmake_parse_arguments(PARSE_ARGV 1
|
||||||
GPU
|
GPU
|
||||||
"WITH_SOABI"
|
"WITH_SOABI"
|
||||||
"DESTINATION;LANGUAGE"
|
"DESTINATION;LANGUAGE;USE_SABI"
|
||||||
"SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
|
"SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
|
||||||
|
|
||||||
# Add hipify preprocessing step when building with HIP/ROCm.
|
# Add hipify preprocessing step when building with HIP/ROCm.
|
||||||
@@ -315,7 +316,11 @@ function (define_gpu_extension_target GPU_MOD_NAME)
|
|||||||
set(GPU_WITH_SOABI)
|
set(GPU_WITH_SOABI)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI})
|
if (GPU_USE_SABI)
|
||||||
|
Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
|
||||||
|
else()
|
||||||
|
Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
|
||||||
|
endif()
|
||||||
|
|
||||||
if (GPU_LANGUAGE STREQUAL "HIP")
|
if (GPU_LANGUAGE STREQUAL "HIP")
|
||||||
# Make this target dependent on the hipify preprocessor step.
|
# Make this target dependent on the hipify preprocessor step.
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ DEFAULT_CONDA_PATTERNS = {
|
|||||||
"triton",
|
"triton",
|
||||||
"optree",
|
"optree",
|
||||||
"nccl",
|
"nccl",
|
||||||
|
"transformers",
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFAULT_PIP_PATTERNS = {
|
DEFAULT_PIP_PATTERNS = {
|
||||||
@@ -75,6 +76,7 @@ DEFAULT_PIP_PATTERNS = {
|
|||||||
"optree",
|
"optree",
|
||||||
"onnx",
|
"onnx",
|
||||||
"nccl",
|
"nccl",
|
||||||
|
"transformers",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -601,6 +603,11 @@ Versions of relevant libraries:
|
|||||||
{conda_packages}
|
{conda_packages}
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
# both the above code and the following code use `strip()` to
|
||||||
|
# remove leading/trailing whitespaces, so we need to add a newline
|
||||||
|
# in between to separate the two sections
|
||||||
|
env_info_fmt += "\n"
|
||||||
|
|
||||||
env_info_fmt += """
|
env_info_fmt += """
|
||||||
ROCM Version: {rocm_version}
|
ROCM Version: {rocm_version}
|
||||||
Neuron SDK Version: {neuron_sdk_version}
|
Neuron SDK Version: {neuron_sdk_version}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
@@ -808,16 +808,17 @@ void paged_attention_v1(
|
|||||||
torch::Tensor&
|
torch::Tensor&
|
||||||
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
torch::Tensor&
|
torch::Tensor&
|
||||||
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
int num_kv_heads, // [num_heads]
|
int64_t num_kv_heads, // [num_heads]
|
||||||
float scale,
|
double scale,
|
||||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
torch::Tensor& seq_lens, // [num_seqs]
|
torch::Tensor& seq_lens, // [num_seqs]
|
||||||
int block_size, int max_seq_len,
|
int64_t block_size, int64_t max_seq_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
|
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
||||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
const int64_t blocksparse_local_blocks,
|
||||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||||
|
const int64_t blocksparse_head_sliding_step) {
|
||||||
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
||||||
|
|
||||||
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
|
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
|
||||||
@@ -972,16 +973,17 @@ void paged_attention_v2(
|
|||||||
torch::Tensor&
|
torch::Tensor&
|
||||||
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
torch::Tensor&
|
torch::Tensor&
|
||||||
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
int num_kv_heads, // [num_heads]
|
int64_t num_kv_heads, // [num_heads]
|
||||||
float scale,
|
double scale,
|
||||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
torch::Tensor& seq_lens, // [num_seqs]
|
torch::Tensor& seq_lens, // [num_seqs]
|
||||||
int block_size, int max_seq_len,
|
int64_t block_size, int64_t max_seq_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
|
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
||||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
const int64_t blocksparse_local_blocks,
|
||||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||||
|
const int64_t blocksparse_head_sliding_step) {
|
||||||
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
||||||
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
|
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
|
||||||
CALL_V2_LAUNCHER_BLOCK_SIZE)
|
CALL_V2_LAUNCHER_BLOCK_SIZE)
|
||||||
@@ -990,4 +992,4 @@ void paged_attention_v2(
|
|||||||
#undef WARP_SIZE
|
#undef WARP_SIZE
|
||||||
#undef MAX
|
#undef MAX
|
||||||
#undef MIN
|
#undef MIN
|
||||||
#undef DIVIDE_ROUND_UP
|
#undef DIVIDE_ROUND_UP
|
||||||
|
|||||||
14
csrc/cache.h
14
csrc/cache.h
@@ -1,6 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@@ -8,14 +8,18 @@
|
|||||||
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||||
const torch::Tensor& block_mapping);
|
const torch::Tensor& block_mapping);
|
||||||
|
|
||||||
void copy_blocks(std::vector<torch::Tensor>& key_caches,
|
// Note: the key_caches and value_caches vectors are constant but
|
||||||
std::vector<torch::Tensor>& value_caches,
|
// not the Tensors they contain. The vectors need to be const refs
|
||||||
|
// in order to satisfy pytorch's C++ operator registration code.
|
||||||
|
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
||||||
|
std::vector<torch::Tensor> const& value_caches,
|
||||||
const torch::Tensor& block_mapping);
|
const torch::Tensor& block_mapping);
|
||||||
|
|
||||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||||
torch::Tensor& slot_mapping,
|
torch::Tensor& slot_mapping,
|
||||||
const std::string& kv_cache_dtype, const float kv_scale);
|
const std::string& kv_cache_dtype,
|
||||||
|
const double kv_scale);
|
||||||
|
|
||||||
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
|
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
|
||||||
torch::Tensor& key_cache,
|
torch::Tensor& key_cache,
|
||||||
@@ -25,4 +29,4 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
|
|||||||
|
|
||||||
// Just for unittest
|
// Just for unittest
|
||||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||||
const float scale, const std::string& kv_cache_dtype);
|
const double scale, const std::string& kv_cache_dtype);
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
@@ -95,8 +95,11 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
|
|||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void copy_blocks(std::vector<torch::Tensor>& key_caches,
|
// Note: the key_caches and value_caches vectors are constant but
|
||||||
std::vector<torch::Tensor>& value_caches,
|
// not the Tensors they contain. The vectors need to be const refs
|
||||||
|
// in order to satisfy pytorch's C++ operator registration code.
|
||||||
|
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
||||||
|
std::vector<torch::Tensor> const& value_caches,
|
||||||
const torch::Tensor& block_mapping) {
|
const torch::Tensor& block_mapping) {
|
||||||
int num_layers = key_caches.size();
|
int num_layers = key_caches.size();
|
||||||
TORCH_CHECK(num_layers == value_caches.size());
|
TORCH_CHECK(num_layers == value_caches.size());
|
||||||
@@ -255,7 +258,7 @@ void reshape_and_cache(
|
|||||||
torch::Tensor&
|
torch::Tensor&
|
||||||
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
torch::Tensor& slot_mapping, // [num_tokens]
|
torch::Tensor& slot_mapping, // [num_tokens]
|
||||||
const std::string& kv_cache_dtype, const float kv_scale) {
|
const std::string& kv_cache_dtype, const double kv_scale) {
|
||||||
int num_tokens = key.size(0);
|
int num_tokens = key.size(0);
|
||||||
int num_heads = key.size(1);
|
int num_heads = key.size(1);
|
||||||
int head_size = key.size(2);
|
int head_size = key.size(2);
|
||||||
@@ -334,7 +337,7 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
|
|||||||
|
|
||||||
// Only for testing.
|
// Only for testing.
|
||||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||||
const float kv_scale, const std::string& kv_cache_dtype) {
|
const double kv_scale, const std::string& kv_cache_dtype) {
|
||||||
torch::Device src_device = src_cache.device();
|
torch::Device src_device = src_cache.device();
|
||||||
torch::Device dst_device = dst_cache.device();
|
torch::Device dst_device = dst_cache.device();
|
||||||
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
|
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
|
||||||
|
|||||||
@@ -420,12 +420,13 @@ void paged_attention_v1_impl_launcher(
|
|||||||
|
|
||||||
void paged_attention_v1(
|
void paged_attention_v1(
|
||||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
|
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||||
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
|
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
||||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
const int64_t blocksparse_local_blocks,
|
||||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||||
|
const int64_t blocksparse_head_sliding_step) {
|
||||||
TORCH_CHECK(kv_scale == 1.0f);
|
TORCH_CHECK(kv_scale == 1.0f);
|
||||||
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
||||||
"CPU backend does not support blocksparse attention yet.");
|
"CPU backend does not support blocksparse attention yet.");
|
||||||
@@ -738,12 +739,13 @@ void paged_attention_v2_impl_launcher(
|
|||||||
void paged_attention_v2(
|
void paged_attention_v2(
|
||||||
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
||||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
|
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||||
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
|
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
||||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
const int64_t blocksparse_local_blocks,
|
||||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||||
|
const int64_t blocksparse_head_sliding_step) {
|
||||||
TORCH_CHECK(kv_scale == 1.0f);
|
TORCH_CHECK(kv_scale == 1.0f);
|
||||||
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
||||||
"CPU backend does not support blocksparse attention yet.");
|
"CPU backend does not support blocksparse attention yet.");
|
||||||
|
|||||||
@@ -5,8 +5,8 @@
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches,
|
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
|
||||||
std::vector<torch::Tensor>& value_caches,
|
std::vector<torch::Tensor> const& value_caches,
|
||||||
const torch::Tensor& mapping_pairs,
|
const torch::Tensor& mapping_pairs,
|
||||||
const int element_num_per_block,
|
const int element_num_per_block,
|
||||||
const int layer_num) {
|
const int layer_num) {
|
||||||
@@ -82,8 +82,11 @@ void reshape_and_cache_cpu_impl(
|
|||||||
}
|
}
|
||||||
}; // namespace
|
}; // namespace
|
||||||
|
|
||||||
void copy_blocks(std::vector<torch::Tensor>& key_caches,
|
// Note: the key_caches and value_caches vectors are constant but
|
||||||
std::vector<torch::Tensor>& value_caches,
|
// not the Tensors they contain. The vectors need to be const refs
|
||||||
|
// in order to satisfy pytorch's C++ operator registration code.
|
||||||
|
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
||||||
|
std::vector<torch::Tensor> const& value_caches,
|
||||||
const torch::Tensor& block_mapping) {
|
const torch::Tensor& block_mapping) {
|
||||||
unsigned num_layers = key_caches.size();
|
unsigned num_layers = key_caches.size();
|
||||||
TORCH_CHECK(num_layers == value_caches.size());
|
TORCH_CHECK(num_layers == value_caches.size());
|
||||||
@@ -104,7 +107,7 @@ void copy_blocks(std::vector<torch::Tensor>& key_caches,
|
|||||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||||
torch::Tensor& slot_mapping,
|
torch::Tensor& slot_mapping,
|
||||||
const std::string& kv_cache_dtype, float kv_scale) {
|
const std::string& kv_cache_dtype, double kv_scale) {
|
||||||
TORCH_CHECK(kv_scale == 1.0f);
|
TORCH_CHECK(kv_scale == 1.0f);
|
||||||
|
|
||||||
int num_tokens = key.size(0);
|
int num_tokens = key.size(0);
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
#define CPU_TYPES_HPP
|
#define CPU_TYPES_HPP
|
||||||
|
|
||||||
#include <immintrin.h>
|
#include <immintrin.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
|
|
||||||
namespace vec_op {
|
namespace vec_op {
|
||||||
|
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||||
float epsilon) {
|
double epsilon) {
|
||||||
int hidden_size = input.size(-1);
|
int hidden_size = input.size(-1);
|
||||||
int num_tokens = input.numel() / hidden_size;
|
int num_tokens = input.numel() / hidden_size;
|
||||||
|
|
||||||
@@ -102,7 +102,7 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
|
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
|
||||||
torch::Tensor& weight, float epsilon) {
|
torch::Tensor& weight, double epsilon) {
|
||||||
int hidden_size = input.size(-1);
|
int hidden_size = input.size(-1);
|
||||||
int num_tokens = input.numel() / hidden_size;
|
int num_tokens = input.numel() / hidden_size;
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,57 @@ void rotary_embedding_impl(
|
|||||||
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
|
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
|
||||||
|
|
||||||
const int embed_dim = rot_dim / 2;
|
const int embed_dim = rot_dim / 2;
|
||||||
TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0);
|
bool flag = (embed_dim % VEC_ELEM_NUM == 0);
|
||||||
|
const int loop_upper = flag ? embed_dim : embed_dim - VEC_ELEM_NUM;
|
||||||
|
|
||||||
|
auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr,
|
||||||
|
scalar_t* qk) {
|
||||||
|
int j = 0;
|
||||||
|
for (; j < loop_upper; j += VEC_ELEM_NUM) {
|
||||||
|
const int rot_offset = j;
|
||||||
|
const int x_index = rot_offset;
|
||||||
|
const int y_index = embed_dim + rot_offset;
|
||||||
|
|
||||||
|
const int64_t out_x = token_head + x_index;
|
||||||
|
const int64_t out_y = token_head + y_index;
|
||||||
|
|
||||||
|
const scalar_vec_t cos(cache_ptr + x_index);
|
||||||
|
const scalar_vec_t sin(cache_ptr + y_index);
|
||||||
|
|
||||||
|
const scalar_vec_t q_x(qk + out_x);
|
||||||
|
const scalar_vec_t q_y(qk + out_y);
|
||||||
|
|
||||||
|
vec_op::FP32Vec8 fp32_cos(cos);
|
||||||
|
vec_op::FP32Vec8 fp32_sin(sin);
|
||||||
|
|
||||||
|
vec_op::FP32Vec8 fp32_q_x(q_x);
|
||||||
|
vec_op::FP32Vec8 fp32_q_y(q_y);
|
||||||
|
|
||||||
|
auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
|
||||||
|
scalar_vec_t(out1).save(qk + out_x);
|
||||||
|
|
||||||
|
auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
|
||||||
|
scalar_vec_t(out2).save(qk + out_y);
|
||||||
|
}
|
||||||
|
if (!flag) {
|
||||||
|
for (; j < embed_dim; ++j) {
|
||||||
|
const int x_index = j;
|
||||||
|
const int y_index = embed_dim + j;
|
||||||
|
|
||||||
|
const int64_t out_x = token_head + x_index;
|
||||||
|
const int64_t out_y = token_head + y_index;
|
||||||
|
|
||||||
|
const float fp32_cos = cache_ptr[x_index];
|
||||||
|
const float fp32_sin = cache_ptr[y_index];
|
||||||
|
|
||||||
|
const float fp32_q_x = qk[out_x];
|
||||||
|
const float fp32_q_y = qk[out_y];
|
||||||
|
|
||||||
|
qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
|
||||||
|
qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
#pragma omp parallel for
|
#pragma omp parallel for
|
||||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||||
@@ -32,62 +82,13 @@ void rotary_embedding_impl(
|
|||||||
const int head_idx = i;
|
const int head_idx = i;
|
||||||
const int64_t token_head =
|
const int64_t token_head =
|
||||||
token_idx * query_stride + head_idx * head_size;
|
token_idx * query_stride + head_idx * head_size;
|
||||||
for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
|
compute_loop(token_head, cache_ptr, query);
|
||||||
const int rot_offset = j;
|
|
||||||
const int x_index = rot_offset;
|
|
||||||
const int y_index = embed_dim + rot_offset;
|
|
||||||
|
|
||||||
const int64_t out_x = token_head + x_index;
|
|
||||||
const int64_t out_y = token_head + y_index;
|
|
||||||
|
|
||||||
const scalar_vec_t cos(cache_ptr + x_index);
|
|
||||||
const scalar_vec_t sin(cache_ptr + y_index);
|
|
||||||
|
|
||||||
const scalar_vec_t q_x(query + out_x);
|
|
||||||
const scalar_vec_t q_y(query + out_y);
|
|
||||||
|
|
||||||
vec_op::FP32Vec8 fp32_cos(cos);
|
|
||||||
vec_op::FP32Vec8 fp32_sin(sin);
|
|
||||||
|
|
||||||
vec_op::FP32Vec8 fp32_q_x(q_x);
|
|
||||||
vec_op::FP32Vec8 fp32_q_y(q_y);
|
|
||||||
|
|
||||||
auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
|
|
||||||
scalar_vec_t(out1).save(query + out_x);
|
|
||||||
|
|
||||||
auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
|
|
||||||
scalar_vec_t(out2).save(query + out_y);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < num_kv_heads; ++i) {
|
for (int i = 0; i < num_kv_heads; ++i) {
|
||||||
const int head_idx = i;
|
const int head_idx = i;
|
||||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||||
for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
|
compute_loop(token_head, cache_ptr, key);
|
||||||
const int rot_offset = j;
|
|
||||||
const int x_index = rot_offset;
|
|
||||||
const int y_index = embed_dim + rot_offset;
|
|
||||||
|
|
||||||
const int64_t out_x = token_head + x_index;
|
|
||||||
const int64_t out_y = token_head + y_index;
|
|
||||||
|
|
||||||
const scalar_vec_t cos(cache_ptr + x_index);
|
|
||||||
const scalar_vec_t sin(cache_ptr + y_index);
|
|
||||||
|
|
||||||
const scalar_vec_t k_x(key + out_x);
|
|
||||||
const scalar_vec_t k_y(key + out_y);
|
|
||||||
|
|
||||||
vec_op::FP32Vec8 fp32_cos(cos);
|
|
||||||
vec_op::FP32Vec8 fp32_sin(sin);
|
|
||||||
|
|
||||||
vec_op::FP32Vec8 fp32_k_x(k_x);
|
|
||||||
vec_op::FP32Vec8 fp32_k_y(k_y);
|
|
||||||
|
|
||||||
auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin;
|
|
||||||
scalar_vec_t(out1).save(key + out_x);
|
|
||||||
auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin;
|
|
||||||
scalar_vec_t(out2).save(key + out_y);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -167,7 +168,7 @@ void rotary_embedding_gptj_impl(
|
|||||||
}; // namespace
|
}; // namespace
|
||||||
|
|
||||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||||
torch::Tensor& key, int head_size,
|
torch::Tensor& key, int64_t head_size,
|
||||||
torch::Tensor& cos_sin_cache, bool is_neox) {
|
torch::Tensor& cos_sin_cache, bool is_neox) {
|
||||||
int num_tokens = query.numel() / query.size(-1);
|
int num_tokens = query.numel() / query.size(-1);
|
||||||
int rot_dim = cos_sin_cache.size(1);
|
int rot_dim = cos_sin_cache.size(1);
|
||||||
|
|||||||
@@ -1,44 +0,0 @@
|
|||||||
#include "cache.h"
|
|
||||||
#include "cuda_utils.h"
|
|
||||||
#include "ops.h"
|
|
||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
||||||
// vLLM custom ops
|
|
||||||
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
|
|
||||||
|
|
||||||
// Attention ops
|
|
||||||
ops.def("paged_attention_v1", &paged_attention_v1,
|
|
||||||
"Compute the attention between an input query and the cached "
|
|
||||||
"keys/values using PagedAttention.");
|
|
||||||
ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
|
|
||||||
|
|
||||||
// Activation ops
|
|
||||||
ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
|
|
||||||
ops.def("gelu_and_mul", &gelu_and_mul,
|
|
||||||
"Activation function used in GeGLU with `none` approximation.");
|
|
||||||
ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
|
|
||||||
"Activation function used in GeGLU with `tanh` approximation.");
|
|
||||||
ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
|
|
||||||
ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
|
|
||||||
|
|
||||||
// Layernorm
|
|
||||||
ops.def("rms_norm", &rms_norm,
|
|
||||||
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
|
|
||||||
|
|
||||||
ops.def("fused_add_rms_norm", &fused_add_rms_norm,
|
|
||||||
"In-place fused Add and RMS Normalization");
|
|
||||||
|
|
||||||
// Rotary embedding
|
|
||||||
ops.def("rotary_embedding", &rotary_embedding,
|
|
||||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
|
||||||
|
|
||||||
// Cache ops
|
|
||||||
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
|
|
||||||
cache_ops.def("swap_blocks", &swap_blocks,
|
|
||||||
"Swap in (out) the cache blocks from src to dst");
|
|
||||||
cache_ops.def("copy_blocks", &copy_blocks,
|
|
||||||
"Copy the cache blocks from src to dst");
|
|
||||||
cache_ops.def("reshape_and_cache", &reshape_and_cache,
|
|
||||||
"Reshape the key and value tensors and cache them");
|
|
||||||
}
|
|
||||||
106
csrc/cpu/torch_bindings.cpp
Normal file
106
csrc/cpu/torch_bindings.cpp
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
#include "cache.h"
|
||||||
|
#include "ops.h"
|
||||||
|
#include "registration.h"
|
||||||
|
|
||||||
|
#include <torch/library.h>
|
||||||
|
|
||||||
|
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||||
|
// vLLM custom ops
|
||||||
|
|
||||||
|
// Attention ops
|
||||||
|
// Compute the attention between an input query and the cached keys/values
|
||||||
|
// using PagedAttention.
|
||||||
|
ops.def(
|
||||||
|
"paged_attention_v1("
|
||||||
|
" Tensor! out, Tensor query, Tensor key_cache,"
|
||||||
|
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||||
|
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||||
|
" int max_seq_len, Tensor? alibi_slopes,"
|
||||||
|
" str kv_cache_dtype, float kv_scale, int tp_rank,"
|
||||||
|
" int blocksparse_local_blocks,"
|
||||||
|
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||||
|
" int blocksparse_head_sliding_step) -> ()");
|
||||||
|
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
|
||||||
|
|
||||||
|
// PagedAttention V2.
|
||||||
|
ops.def(
|
||||||
|
"paged_attention_v2("
|
||||||
|
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
|
||||||
|
" Tensor tmp_out, Tensor query, Tensor key_cache,"
|
||||||
|
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||||
|
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||||
|
" int max_seq_len, Tensor? alibi_slopes,"
|
||||||
|
" str kv_cache_dtype, float kv_scale, int tp_rank,"
|
||||||
|
" int blocksparse_local_blocks,"
|
||||||
|
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||||
|
" int blocksparse_head_sliding_step) -> ()");
|
||||||
|
ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);
|
||||||
|
|
||||||
|
// Activation ops
|
||||||
|
|
||||||
|
// Activation function used in SwiGLU.
|
||||||
|
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
|
||||||
|
ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul);
|
||||||
|
|
||||||
|
// Activation function used in GeGLU with `none` approximation.
|
||||||
|
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
|
||||||
|
ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul);
|
||||||
|
|
||||||
|
// Activation function used in GeGLU with `tanh` approximation.
|
||||||
|
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
|
||||||
|
ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul);
|
||||||
|
|
||||||
|
// GELU implementation used in GPT-2.
|
||||||
|
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
|
||||||
|
ops.impl("gelu_new", torch::kCPU, &gelu_new);
|
||||||
|
|
||||||
|
// Approximate GELU implementation.
|
||||||
|
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
|
||||||
|
ops.impl("gelu_fast", torch::kCPU, &gelu_fast);
|
||||||
|
|
||||||
|
// Layernorm
|
||||||
|
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
||||||
|
ops.def(
|
||||||
|
"rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
|
||||||
|
"()");
|
||||||
|
ops.impl("rms_norm", torch::kCPU, &rms_norm);
|
||||||
|
|
||||||
|
// In-place fused Add and RMS Normalization.
|
||||||
|
ops.def(
|
||||||
|
"fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
|
||||||
|
"float epsilon) -> ()");
|
||||||
|
ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm);
|
||||||
|
|
||||||
|
// Rotary embedding
|
||||||
|
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
|
||||||
|
ops.def(
|
||||||
|
"rotary_embedding(Tensor positions, Tensor! query,"
|
||||||
|
" Tensor! key, int head_size,"
|
||||||
|
" Tensor cos_sin_cache, bool is_neox) -> ()");
|
||||||
|
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||||
|
// Cache ops
|
||||||
|
// Swap in (out) the cache blocks from src to dst.
|
||||||
|
cache_ops.def(
|
||||||
|
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
|
||||||
|
cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks);
|
||||||
|
|
||||||
|
// Copy the cache blocks from src to dst.
|
||||||
|
cache_ops.def(
|
||||||
|
"copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
|
||||||
|
"block_mapping) -> ()");
|
||||||
|
cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);
|
||||||
|
|
||||||
|
// Reshape the key and value tensors and cache them.
|
||||||
|
cache_ops.def(
|
||||||
|
"reshape_and_cache(Tensor key, Tensor value,"
|
||||||
|
" Tensor! key_cache, Tensor! value_cache,"
|
||||||
|
" Tensor slot_mapping,"
|
||||||
|
" str kv_cache_dtype,"
|
||||||
|
" float kv_scale) -> ()");
|
||||||
|
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||||
@@ -19,8 +19,12 @@
|
|||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
|
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
|
||||||
__shfl_xor_sync(uint32_t(-1), var, lane_mask)
|
__shfl_xor_sync(uint32_t(-1), var, lane_mask)
|
||||||
|
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
|
||||||
|
__shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
|
||||||
#else
|
#else
|
||||||
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
|
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
|
||||||
|
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
|
||||||
|
__shfl_xor(var, lane_mask, width)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/extension.h>
|
int64_t get_device_attribute(int64_t attribute, int64_t device_id);
|
||||||
|
|
||||||
int get_device_attribute(int attribute, int device_id);
|
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
|
||||||
|
|
||||||
int get_max_shared_memory_per_block_device_attribute(int device_id);
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
#include <hip/hip_runtime.h>
|
#include <hip/hip_runtime.h>
|
||||||
#include <hip/hip_runtime_api.h>
|
#include <hip/hip_runtime_api.h>
|
||||||
#endif
|
#endif
|
||||||
int get_device_attribute(int attribute, int device_id) {
|
int64_t get_device_attribute(int64_t attribute, int64_t device_id) {
|
||||||
int device, value;
|
int device, value;
|
||||||
if (device_id < 0) {
|
if (device_id < 0) {
|
||||||
cudaGetDevice(&device);
|
cudaGetDevice(&device);
|
||||||
@@ -14,8 +14,8 @@ int get_device_attribute(int attribute, int device_id) {
|
|||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
int get_max_shared_memory_per_block_device_attribute(int device_id) {
|
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) {
|
||||||
int attribute;
|
int64_t attribute;
|
||||||
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
|
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
|
||||||
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
|
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
|
||||||
|
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
#include <ATen/cuda/Exceptions.h>
|
#include <ATen/cuda/Exceptions.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
#include <c10/cuda/CUDAStream.h>
|
#include <c10/cuda/CUDAStream.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
|
|
||||||
#include "custom_all_reduce.cuh"
|
#include "custom_all_reduce.cuh"
|
||||||
|
|
||||||
// fake pointer type
|
// fake pointer type, must match fptr_t type in ops.h
|
||||||
using fptr_t = uint64_t;
|
using fptr_t = int64_t;
|
||||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||||
|
|
||||||
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
|
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
|
||||||
const std::vector<std::string>& handles,
|
const std::vector<std::string>& handles,
|
||||||
const std::vector<int64_t>& offsets, int rank,
|
const std::vector<int64_t>& offsets, int64_t rank,
|
||||||
bool full_nvlink) {
|
bool full_nvlink) {
|
||||||
int world_size = offsets.size();
|
int world_size = offsets.size();
|
||||||
if (world_size > 8)
|
if (world_size > 8)
|
||||||
@@ -55,7 +55,7 @@ bool _is_weak_contiguous(torch::Tensor& t) {
|
|||||||
t.numel() * t.element_size());
|
t.numel() * t.element_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
|
bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
|
||||||
bool full_nvlink) {
|
bool full_nvlink) {
|
||||||
auto inp_size = inp.numel() * inp.element_size();
|
auto inp_size = inp.numel() * inp.element_size();
|
||||||
// custom allreduce requires input byte size to be multiples of 16
|
// custom allreduce requires input byte size to be multiples of 16
|
||||||
@@ -125,7 +125,7 @@ void dispose(fptr_t _fa) {
|
|||||||
delete fa;
|
delete fa;
|
||||||
}
|
}
|
||||||
|
|
||||||
int meta_size() { return sizeof(vllm::Signal); }
|
int64_t meta_size() { return sizeof(vllm::Signal); }
|
||||||
|
|
||||||
void register_buffer(fptr_t _fa, torch::Tensor& t,
|
void register_buffer(fptr_t _fa, torch::Tensor& t,
|
||||||
const std::vector<std::string>& handles,
|
const std::vector<std::string>& handles,
|
||||||
@@ -134,10 +134,16 @@ void register_buffer(fptr_t _fa, torch::Tensor& t,
|
|||||||
fa->register_buffer(handles, offsets, t.data_ptr());
|
fa->register_buffer(handles, offsets, t.data_ptr());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
||||||
fptr_t _fa) {
|
fptr_t _fa) {
|
||||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||||
return fa->get_graph_buffer_ipc_meta();
|
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
|
||||||
|
auto options =
|
||||||
|
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
||||||
|
auto handles =
|
||||||
|
torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
|
||||||
|
std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
|
||||||
|
return {handles, std::move(offsets)};
|
||||||
}
|
}
|
||||||
|
|
||||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
|
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
|
|
||||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
@@ -291,7 +291,7 @@ fused_add_rms_norm_kernel(
|
|||||||
void rms_norm(torch::Tensor& out, // [..., hidden_size]
|
void rms_norm(torch::Tensor& out, // [..., hidden_size]
|
||||||
torch::Tensor& input, // [..., hidden_size]
|
torch::Tensor& input, // [..., hidden_size]
|
||||||
torch::Tensor& weight, // [hidden_size]
|
torch::Tensor& weight, // [hidden_size]
|
||||||
float epsilon) {
|
double epsilon) {
|
||||||
int hidden_size = input.size(-1);
|
int hidden_size = input.size(-1);
|
||||||
int num_tokens = input.numel() / hidden_size;
|
int num_tokens = input.numel() / hidden_size;
|
||||||
|
|
||||||
@@ -319,7 +319,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
|
|||||||
void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
|
void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
|
||||||
torch::Tensor& residual, // [..., hidden_size]
|
torch::Tensor& residual, // [..., hidden_size]
|
||||||
torch::Tensor& weight, // [hidden_size]
|
torch::Tensor& weight, // [hidden_size]
|
||||||
float epsilon) {
|
double epsilon) {
|
||||||
int hidden_size = input.size(-1);
|
int hidden_size = input.size(-1);
|
||||||
int num_tokens = input.numel() / hidden_size;
|
int num_tokens = input.numel() / hidden_size;
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +0,0 @@
|
|||||||
#include "moe_ops.h"
|
|
||||||
|
|
||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
||||||
m.def("topk_softmax", &topk_softmax,
|
|
||||||
"Apply topk softmax to the gating outputs.");
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
|
|
||||||
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
|
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
|
||||||
torch::Tensor& token_expert_indices,
|
torch::Tensor& token_expert_indices,
|
||||||
|
|||||||
@@ -16,18 +16,25 @@
|
|||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include "../cuda_compat.h"
|
||||||
|
|
||||||
#include <cub/cub.cuh>
|
#ifndef USE_ROCM
|
||||||
#include <cub/util_type.cuh>
|
#include <cub/util_type.cuh>
|
||||||
|
#include <cub/cub.cuh>
|
||||||
|
#else
|
||||||
|
#include <hipcub/util_type.hpp>
|
||||||
|
#include <hipcub/hipcub.hpp>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||||
|
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
namespace moe {
|
namespace moe {
|
||||||
|
|
||||||
static constexpr int WARP_SIZE = 32;
|
|
||||||
|
|
||||||
/// Aligned array type
|
/// Aligned array type
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
@@ -265,7 +272,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||||
{
|
{
|
||||||
thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
|
thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW));
|
||||||
}
|
}
|
||||||
|
|
||||||
// From this point, thread max in all the threads have the max within the row.
|
// From this point, thread max in all the threads have the max within the row.
|
||||||
@@ -282,7 +289,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||||
{
|
{
|
||||||
row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
|
row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW);
|
||||||
}
|
}
|
||||||
|
|
||||||
// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
|
// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
|
||||||
@@ -332,8 +339,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||||
{
|
{
|
||||||
float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
|
float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
|
||||||
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
|
int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);
|
||||||
|
|
||||||
// We want lower indices to "win" in every thread so we break ties this way
|
// We want lower indices to "win" in every thread so we break ties this way
|
||||||
if (other_max > max_val || (other_max == max_val && other_expert < expert))
|
if (other_max > max_val || (other_max == max_val && other_expert < expert))
|
||||||
@@ -383,7 +390,7 @@ struct TopkConstants
|
|||||||
{
|
{
|
||||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
|
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
|
||||||
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
|
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
|
||||||
static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
|
static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
|
||||||
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
|
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
|
||||||
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
|
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
|
||||||
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
|
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
|
||||||
@@ -396,7 +403,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
|
|||||||
{
|
{
|
||||||
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
|
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
|
||||||
|
|
||||||
static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
|
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
|
||||||
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
|
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
|
||||||
static constexpr int VPT = Constants::VPT;
|
static constexpr int VPT = Constants::VPT;
|
||||||
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
||||||
|
|||||||
12
csrc/moe/torch_bindings.cpp
Normal file
12
csrc/moe/torch_bindings.cpp
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
#include "registration.h"
|
||||||
|
#include "moe_ops.h"
|
||||||
|
|
||||||
|
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||||
|
// Apply topk softmax to the gating outputs.
|
||||||
|
m.def(
|
||||||
|
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
|
||||||
|
"token_expert_indices, Tensor gating_output) -> ()");
|
||||||
|
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
@@ -108,8 +108,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
|
|||||||
}
|
}
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
|
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||||
int block_size, torch::Tensor sorted_token_ids,
|
int64_t block_size, torch::Tensor sorted_token_ids,
|
||||||
torch::Tensor experts_ids,
|
torch::Tensor experts_ids,
|
||||||
torch::Tensor num_tokens_post_pad) {
|
torch::Tensor num_tokens_post_pad) {
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|||||||
75
csrc/ops.h
75
csrc/ops.h
@@ -1,40 +1,42 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/library.h>
|
||||||
|
|
||||||
void paged_attention_v1(
|
void paged_attention_v1(
|
||||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
|
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||||
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
|
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
||||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
const int64_t blocksparse_local_blocks,
|
||||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step);
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||||
|
const int64_t blocksparse_head_sliding_step);
|
||||||
|
|
||||||
void paged_attention_v2(
|
void paged_attention_v2(
|
||||||
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
||||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
|
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||||
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
|
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
||||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
const int64_t blocksparse_local_blocks,
|
||||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step);
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||||
|
const int64_t blocksparse_head_sliding_step);
|
||||||
|
|
||||||
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||||
float epsilon);
|
double epsilon);
|
||||||
|
|
||||||
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
|
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
|
||||||
torch::Tensor& weight, float epsilon);
|
torch::Tensor& weight, double epsilon);
|
||||||
|
|
||||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||||
torch::Tensor& key, int head_size,
|
torch::Tensor& key, int64_t head_size,
|
||||||
torch::Tensor& cos_sin_cache, bool is_neox);
|
torch::Tensor& cos_sin_cache, bool is_neox);
|
||||||
|
|
||||||
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||||
torch::Tensor& key, int head_size,
|
torch::Tensor& key, int64_t head_size,
|
||||||
torch::Tensor& cos_sin_cache, bool is_neox,
|
torch::Tensor& cos_sin_cache, bool is_neox,
|
||||||
int rot_dim,
|
int64_t rot_dim,
|
||||||
torch::Tensor& cos_sin_cache_offsets);
|
torch::Tensor& cos_sin_cache_offsets);
|
||||||
|
|
||||||
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||||
@@ -60,12 +62,12 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes,
|
|||||||
|
|
||||||
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
|
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
|
||||||
torch::Tensor _scaling_factors, torch::Tensor _zeros,
|
torch::Tensor _scaling_factors, torch::Tensor _zeros,
|
||||||
int split_k_iters);
|
int64_t split_k_iters);
|
||||||
|
|
||||||
torch::Tensor awq_dequantize(torch::Tensor _kernel,
|
torch::Tensor awq_dequantize(torch::Tensor _kernel,
|
||||||
torch::Tensor _scaling_factors,
|
torch::Tensor _scaling_factors,
|
||||||
torch::Tensor _zeros, int split_k_iters, int thx,
|
torch::Tensor _zeros, int64_t split_k_iters,
|
||||||
int thy);
|
int64_t thx, int64_t thy);
|
||||||
|
|
||||||
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||||
torch::Tensor& b_scales, torch::Tensor& workspace,
|
torch::Tensor& b_scales, torch::Tensor& workspace,
|
||||||
@@ -88,14 +90,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
|||||||
int64_t size_k, int64_t size_n,
|
int64_t size_k, int64_t size_n,
|
||||||
int64_t num_bits);
|
int64_t num_bits);
|
||||||
|
|
||||||
int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
|
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales);
|
torch::Tensor const& b_scales);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
|
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||||
float scale);
|
torch::Tensor const& scale);
|
||||||
|
|
||||||
|
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||||
|
torch::Tensor& scales);
|
||||||
|
|
||||||
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||||
torch::Tensor lookup_table);
|
torch::Tensor lookup_table);
|
||||||
@@ -103,9 +108,9 @@ void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
|||||||
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
||||||
torch::Tensor b_gptq_qzeros,
|
torch::Tensor b_gptq_qzeros,
|
||||||
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
|
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
|
||||||
bool use_exllama, int bit);
|
bool use_exllama, int64_t bit);
|
||||||
|
|
||||||
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit);
|
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
|
||||||
|
|
||||||
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
||||||
torch::Tensor& scale);
|
torch::Tensor& scale);
|
||||||
@@ -113,28 +118,28 @@ void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
|||||||
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
||||||
torch::Tensor& scale);
|
torch::Tensor& scale);
|
||||||
|
|
||||||
void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
|
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||||
int block_size, torch::Tensor sorted_token_ids,
|
int64_t block_size, torch::Tensor sorted_token_ids,
|
||||||
torch::Tensor experts_ids,
|
torch::Tensor experts_ids,
|
||||||
torch::Tensor num_tokens_post_pad);
|
torch::Tensor num_tokens_post_pad);
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
using fptr_t = uint64_t;
|
using fptr_t = int64_t;
|
||||||
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
|
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
|
||||||
const std::vector<std::string>& handles,
|
const std::vector<std::string>& handles,
|
||||||
const std::vector<int64_t>& offsets, int rank,
|
const std::vector<int64_t>& offsets, int64_t rank,
|
||||||
bool full_nvlink);
|
bool full_nvlink);
|
||||||
bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
|
bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
|
||||||
bool full_nvlink);
|
bool full_nvlink);
|
||||||
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
|
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
|
||||||
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
|
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
|
||||||
torch::Tensor& out);
|
torch::Tensor& out);
|
||||||
void dispose(fptr_t _fa);
|
void dispose(fptr_t _fa);
|
||||||
int meta_size();
|
int64_t meta_size();
|
||||||
void register_buffer(fptr_t _fa, torch::Tensor& t,
|
void register_buffer(fptr_t _fa, torch::Tensor& t,
|
||||||
const std::vector<std::string>& handles,
|
const std::vector<std::string>& handles,
|
||||||
const std::vector<int64_t>& offsets);
|
const std::vector<int64_t>& offsets);
|
||||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
||||||
fptr_t _fa);
|
fptr_t _fa);
|
||||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
|
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
|
||||||
const std::vector<std::vector<int64_t>>& offsets);
|
const std::vector<std::vector<int64_t>>& offsets);
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
@@ -127,7 +127,7 @@ void rotary_embedding(
|
|||||||
// [num_tokens, num_heads * head_size]
|
// [num_tokens, num_heads * head_size]
|
||||||
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
|
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
|
||||||
// [num_tokens, num_kv_heads * head_size]
|
// [num_tokens, num_kv_heads * head_size]
|
||||||
int head_size,
|
int64_t head_size,
|
||||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||||
bool is_neox) {
|
bool is_neox) {
|
||||||
int64_t num_tokens = query.numel() / query.size(-1);
|
int64_t num_tokens = query.numel() / query.size(-1);
|
||||||
@@ -138,7 +138,7 @@ void rotary_embedding(
|
|||||||
int64_t key_stride = key.stride(-2);
|
int64_t key_stride = key.stride(-2);
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
|
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
|
||||||
@@ -168,9 +168,9 @@ void batched_rotary_embedding(
|
|||||||
// [num_tokens, num_heads * head_size]
|
// [num_tokens, num_heads * head_size]
|
||||||
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
|
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
|
||||||
// [num_tokens, num_kv_heads * head_size]
|
// [num_tokens, num_kv_heads * head_size]
|
||||||
int head_size,
|
int64_t head_size,
|
||||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||||
bool is_neox, int rot_dim,
|
bool is_neox, int64_t rot_dim,
|
||||||
torch::Tensor& cos_sin_cache_offsets // [num_tokens]
|
torch::Tensor& cos_sin_cache_offsets // [num_tokens]
|
||||||
) {
|
) {
|
||||||
int64_t num_tokens = cos_sin_cache_offsets.size(0);
|
int64_t num_tokens = cos_sin_cache_offsets.size(0);
|
||||||
@@ -180,7 +180,7 @@ void batched_rotary_embedding(
|
|||||||
int64_t key_stride = key.stride(-2);
|
int64_t key_stride = key.stride(-2);
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
|
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
@@ -88,7 +88,7 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||||
torch::Tensor indicies, int64_t layer_idx, float scale) {
|
torch::Tensor indicies, int64_t layer_idx, double scale) {
|
||||||
CHECK_INPUT(y);
|
CHECK_INPUT(y);
|
||||||
CHECK_INPUT(x);
|
CHECK_INPUT(x);
|
||||||
CHECK_INPUT(w);
|
CHECK_INPUT(w);
|
||||||
@@ -320,7 +320,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
|||||||
|
|
||||||
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||||
torch::Tensor indicies, int64_t layer_idx,
|
torch::Tensor indicies, int64_t layer_idx,
|
||||||
float scale, int64_t h_in, int64_t h_out,
|
double scale, int64_t h_in, int64_t h_out,
|
||||||
int64_t y_offset) {
|
int64_t y_offset) {
|
||||||
CHECK_INPUT(y);
|
CHECK_INPUT(y);
|
||||||
CHECK_INPUT(x);
|
CHECK_INPUT(x);
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
|
|
||||||
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||||
torch::Tensor indicies, int64_t layer_idx, float scale);
|
torch::Tensor indicies, int64_t layer_idx, double scale);
|
||||||
|
|
||||||
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||||
torch::Tensor indicies, int64_t layer_idx,
|
torch::Tensor indicies, int64_t layer_idx,
|
||||||
float scale, int64_t h_in, int64_t h_out,
|
double scale, int64_t h_in, int64_t h_out,
|
||||||
int64_t y_offset);
|
int64_t y_offset);
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
#include "punica_ops.h"
|
|
||||||
|
|
||||||
//====== pybind ======
|
|
||||||
|
|
||||||
#define DEFINE_pybind(name) m.def(#name, &name, #name);
|
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
||||||
m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
|
|
||||||
m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
|
|
||||||
"dispatch_bgmv_low_level");
|
|
||||||
}
|
|
||||||
18
csrc/punica/torch_bindings.cpp
Normal file
18
csrc/punica/torch_bindings.cpp
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
#include "registration.h"
|
||||||
|
#include "punica_ops.h"
|
||||||
|
|
||||||
|
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def(
|
||||||
|
"dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int "
|
||||||
|
"layer_idx, float scale) -> ()");
|
||||||
|
m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv);
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w,"
|
||||||
|
"Tensor indicies, int layer_idx,"
|
||||||
|
"float scale, int h_in, int h_out,"
|
||||||
|
"int y_offset) -> ()");
|
||||||
|
m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level);
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||||
111
csrc/pybind.cpp
111
csrc/pybind.cpp
@@ -1,111 +0,0 @@
|
|||||||
#include "cache.h"
|
|
||||||
#include "cuda_utils.h"
|
|
||||||
#include "ops.h"
|
|
||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
||||||
// vLLM custom ops
|
|
||||||
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
|
|
||||||
|
|
||||||
// Attention ops
|
|
||||||
ops.def("paged_attention_v1", &paged_attention_v1,
|
|
||||||
"Compute the attention between an input query and the cached "
|
|
||||||
"keys/values using PagedAttention.");
|
|
||||||
ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
|
|
||||||
|
|
||||||
// Activation ops
|
|
||||||
ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
|
|
||||||
ops.def("gelu_and_mul", &gelu_and_mul,
|
|
||||||
"Activation function used in GeGLU with `none` approximation.");
|
|
||||||
ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
|
|
||||||
"Activation function used in GeGLU with `tanh` approximation.");
|
|
||||||
ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
|
|
||||||
ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
|
|
||||||
|
|
||||||
// Layernorm
|
|
||||||
ops.def("rms_norm", &rms_norm,
|
|
||||||
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
|
|
||||||
|
|
||||||
ops.def("fused_add_rms_norm", &fused_add_rms_norm,
|
|
||||||
"In-place fused Add and RMS Normalization");
|
|
||||||
|
|
||||||
// Rotary embedding
|
|
||||||
ops.def("rotary_embedding", &rotary_embedding,
|
|
||||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
|
||||||
|
|
||||||
ops.def("batched_rotary_embedding", &batched_rotary_embedding,
|
|
||||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key "
|
|
||||||
"(supports multiple loras)");
|
|
||||||
|
|
||||||
// Quantization ops
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
|
|
||||||
ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
|
|
||||||
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
|
|
||||||
ops.def("marlin_gemm", &marlin_gemm,
|
|
||||||
"Marlin (Dense) Optimized Quantized GEMM for GPTQ");
|
|
||||||
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm,
|
|
||||||
"Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ");
|
|
||||||
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm,
|
|
||||||
"gptq_marlin Optimized Quantized GEMM for GPTQ");
|
|
||||||
ops.def("gptq_marlin_repack", &gptq_marlin_repack,
|
|
||||||
"gptq_marlin repack from GPTQ");
|
|
||||||
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
|
|
||||||
ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq,
|
|
||||||
"CUTLASS w8a8 GEMM, supporting symmetric per-tensor or "
|
|
||||||
"per-row/column quantization.");
|
|
||||||
#endif
|
|
||||||
|
|
||||||
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
|
|
||||||
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
|
|
||||||
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
|
|
||||||
ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant,
|
|
||||||
"Compute FP8 quantized tensor for given scaling factor");
|
|
||||||
ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant,
|
|
||||||
"Compute FP8 quantized tensor and scaling factor");
|
|
||||||
ops.def("moe_align_block_size", &moe_align_block_size,
|
|
||||||
"Aligning the number of tokens to be processed by each expert such "
|
|
||||||
"that it is divisible by the block size.");
|
|
||||||
|
|
||||||
ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
|
|
||||||
"Compute int8 quantized tensor for given scaling factor");
|
|
||||||
|
|
||||||
// Cache ops
|
|
||||||
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
|
|
||||||
cache_ops.def("swap_blocks", &swap_blocks,
|
|
||||||
"Swap in (out) the cache blocks from src to dst");
|
|
||||||
cache_ops.def("copy_blocks", &copy_blocks,
|
|
||||||
"Copy the cache blocks from src to dst");
|
|
||||||
cache_ops.def("reshape_and_cache", &reshape_and_cache,
|
|
||||||
"Reshape the key and value tensors and cache them");
|
|
||||||
cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash,
|
|
||||||
"Reshape the key and value tensors and cache them");
|
|
||||||
cache_ops.def("convert_fp8", &convert_fp8,
|
|
||||||
"Convert the key and value cache to fp8 data type");
|
|
||||||
|
|
||||||
// Cuda utils
|
|
||||||
pybind11::module cuda_utils =
|
|
||||||
m.def_submodule("cuda_utils", "vLLM cuda utils");
|
|
||||||
cuda_utils.def("get_device_attribute", &get_device_attribute,
|
|
||||||
"Gets the specified device attribute.");
|
|
||||||
|
|
||||||
cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
|
|
||||||
&get_max_shared_memory_per_block_device_attribute,
|
|
||||||
"Gets the maximum shared memory per block device attribute.");
|
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
// Custom all-reduce kernels
|
|
||||||
pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
|
|
||||||
custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
|
|
||||||
custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
|
|
||||||
custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
|
|
||||||
custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
|
|
||||||
custom_ar.def("dispose", &dispose, "dispose");
|
|
||||||
custom_ar.def("meta_size", &meta_size, "meta_size");
|
|
||||||
custom_ar.def("register_buffer", &register_buffer, "register_buffer");
|
|
||||||
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
|
|
||||||
"get_graph_buffer_ipc_meta");
|
|
||||||
custom_ar.def("register_graph_buffers", &register_graph_buffers,
|
|
||||||
"register_graph_buffers");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
@@ -18,7 +18,7 @@
|
|||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
#include <c10/cuda/CUDAStream.h>
|
#include <c10/cuda/CUDAStream.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
|
|||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
#include "dequantize.cuh"
|
#include "dequantize.cuh"
|
||||||
@@ -435,8 +435,8 @@ __global__ void __launch_bounds__(64)
|
|||||||
|
|
||||||
torch::Tensor awq_dequantize(torch::Tensor _kernel,
|
torch::Tensor awq_dequantize(torch::Tensor _kernel,
|
||||||
torch::Tensor _scaling_factors,
|
torch::Tensor _scaling_factors,
|
||||||
torch::Tensor _zeros, int split_k_iters, int thx,
|
torch::Tensor _zeros, int64_t split_k_iters,
|
||||||
int thy) {
|
int64_t thx, int64_t thy) {
|
||||||
int in_c = _kernel.size(0);
|
int in_c = _kernel.size(0);
|
||||||
int qout_c = _kernel.size(1);
|
int qout_c = _kernel.size(1);
|
||||||
int out_c = qout_c * 8;
|
int out_c = qout_c * 8;
|
||||||
@@ -491,7 +491,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
|
|||||||
|
|
||||||
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
|
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
|
||||||
torch::Tensor _scaling_factors, torch::Tensor _zeros,
|
torch::Tensor _scaling_factors, torch::Tensor _zeros,
|
||||||
int split_k_iters) {
|
int64_t split_k_iters) {
|
||||||
int num_in_feats = _in_feats.size(0);
|
int num_in_feats = _in_feats.size(0);
|
||||||
int num_in_channels = _in_feats.size(1);
|
int num_in_channels = _in_feats.size(1);
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
#include "../../dispatch_utils.h"
|
#include "../../dispatch_utils.h"
|
||||||
|
#include "../../reduction_utils.cuh"
|
||||||
|
|
||||||
static inline __device__ int8_t float_to_int8_rn(float x) {
|
static inline __device__ int8_t float_to_int8_rn(float x) {
|
||||||
#ifdef USE_ROCM
|
#ifdef USE_ROCM
|
||||||
@@ -27,33 +28,88 @@ namespace vllm {
|
|||||||
|
|
||||||
template <typename scalar_t, typename scale_type>
|
template <typename scalar_t, typename scale_type>
|
||||||
__global__ void static_scaled_int8_quant_kernel(
|
__global__ void static_scaled_int8_quant_kernel(
|
||||||
const scalar_t* __restrict__ input, int8_t* __restrict__ out,
|
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
|
||||||
scale_type scale, const int hidden_size) {
|
scale_type const* scale_ptr, const int hidden_size) {
|
||||||
const int tid = threadIdx.x;
|
int const tid = threadIdx.x;
|
||||||
const int token_idx = blockIdx.x;
|
int const token_idx = blockIdx.x;
|
||||||
|
scale_type const scale = *scale_ptr;
|
||||||
|
|
||||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||||
out[token_idx * hidden_size + i] =
|
out[token_idx * hidden_size + i] = float_to_int8_rn(
|
||||||
float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
|
static_cast<float>(input[token_idx * hidden_size + i]) / scale);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t, typename scale_type>
|
||||||
|
__global__ void dynamic_scaled_int8_quant_kernel(
|
||||||
|
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
|
||||||
|
scale_type* scale, const int hidden_size) {
|
||||||
|
int const tid = threadIdx.x;
|
||||||
|
int const token_idx = blockIdx.x;
|
||||||
|
float absmax_val = 0.0f;
|
||||||
|
float const zero = 0.0f;
|
||||||
|
|
||||||
|
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||||
|
float val = static_cast<float>(input[token_idx * hidden_size + i]);
|
||||||
|
val = val > zero ? val : -val;
|
||||||
|
absmax_val = val > absmax_val ? val : absmax_val;
|
||||||
|
}
|
||||||
|
|
||||||
|
float const block_absmax_val_maybe = blockReduceMax(absmax_val);
|
||||||
|
__shared__ float block_absmax_val;
|
||||||
|
if (tid == 0) {
|
||||||
|
block_absmax_val = block_absmax_val_maybe;
|
||||||
|
scale[token_idx] = block_absmax_val / 127.0f;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
float const tmp_scale = 127.0f / block_absmax_val;
|
||||||
|
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||||
|
out[token_idx * hidden_size + i] = float_to_int8_rn(
|
||||||
|
static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||||
torch::Tensor& input, // [..., hidden_size]
|
torch::Tensor const& input, // [..., hidden_size]
|
||||||
float scale) {
|
torch::Tensor const& scale) {
|
||||||
TORCH_CHECK(input.is_contiguous());
|
TORCH_CHECK(input.is_contiguous());
|
||||||
TORCH_CHECK(out.is_contiguous());
|
TORCH_CHECK(out.is_contiguous());
|
||||||
int hidden_size = input.size(-1);
|
TORCH_CHECK(scale.numel() == 1);
|
||||||
int num_tokens = input.numel() / hidden_size;
|
|
||||||
dim3 grid(num_tokens);
|
int const hidden_size = input.size(-1);
|
||||||
dim3 block(std::min(hidden_size, 1024));
|
int const num_tokens = input.numel() / hidden_size;
|
||||||
|
dim3 const grid(num_tokens);
|
||||||
|
dim3 const block(std::min(hidden_size, 1024));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
|
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
|
||||||
vllm::static_scaled_int8_quant_kernel<scalar_t, float>
|
vllm::static_scaled_int8_quant_kernel<scalar_t, float>
|
||||||
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
|
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
|
||||||
out.data_ptr<int8_t>(), scale,
|
out.data_ptr<int8_t>(),
|
||||||
hidden_size);
|
scale.data_ptr<float>(), hidden_size);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void dynamic_scaled_int8_quant(
|
||||||
|
torch::Tensor& out, // [..., hidden_size]
|
||||||
|
torch::Tensor const& input, // [..., hidden_size]
|
||||||
|
torch::Tensor& scales) {
|
||||||
|
TORCH_CHECK(input.is_contiguous());
|
||||||
|
TORCH_CHECK(out.is_contiguous());
|
||||||
|
|
||||||
|
int const hidden_size = input.size(-1);
|
||||||
|
int const num_tokens = input.numel() / hidden_size;
|
||||||
|
dim3 const grid(num_tokens);
|
||||||
|
dim3 const block(std::min(hidden_size, 1024));
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
|
input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
|
||||||
|
vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
|
||||||
|
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
|
||||||
|
out.data_ptr<int8_t>(),
|
||||||
|
scales.data_ptr<float>(), hidden_size);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,20 +33,27 @@
|
|||||||
//
|
//
|
||||||
// This file is a modified excerpt of
|
// This file is a modified excerpt of
|
||||||
// include/cutlass/epilogue/fusion/visitor_load.hpp from
|
// include/cutlass/epilogue/fusion/visitor_load.hpp from
|
||||||
// https://github.com/NVIDIA/cutlass It's beem modified to support either
|
// https://github.com/NVIDIA/cutlass v3.5.0
|
||||||
// row/column or scalar broadcasting, like is already supported in CUTLASS 3.x.
|
// It has been modified to support either
|
||||||
// Important because this saves us a factor 4x on the number of kernels
|
// row/column or scalar broadcasting where the tensor being loaded from is
|
||||||
// compiled.
|
// always passed in via a device pointer. This lets one compiled kernel handle
|
||||||
|
// all cases of per-tensor or per-channel/per-token quantization.
|
||||||
|
//
|
||||||
|
// This interface also allows the scales to be passed in as tensors that
|
||||||
|
// consistently reside on the device, which avoids an issue with a previous
|
||||||
|
// implementation where scalars needed to be on the CPU since they
|
||||||
|
// were passed in via float values. This created a potential performance hazard
|
||||||
|
// if scales were initially on the device, and caused torch.compile graph
|
||||||
|
// breaks when moving scales to the CPU.
|
||||||
//
|
//
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
// Turn off clang-format for the entire file to keep it close to upstream
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
|
||||||
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
|
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
|
||||||
#include "cute/tensor.hpp"
|
#include "cute/tensor.hpp"
|
||||||
|
|
||||||
// clang-format on
|
|
||||||
|
|
||||||
namespace cutlass::epilogue::threadblock {
|
namespace cutlass::epilogue::threadblock {
|
||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
@@ -59,9 +66,11 @@ template<
|
|||||||
>
|
>
|
||||||
struct VisitorRowOrScalarBroadcast {
|
struct VisitorRowOrScalarBroadcast {
|
||||||
|
|
||||||
|
// This struct has been modified to have a bool indicating that ptr_row is a
|
||||||
|
// scalar that must be broadcast.
|
||||||
struct Arguments {
|
struct Arguments {
|
||||||
Element const* ptr_row = nullptr;
|
Element const* ptr_row = nullptr;
|
||||||
Element null_default = Element(0);
|
bool row_broadcast = true;
|
||||||
StrideMNL dRow = {};
|
StrideMNL dRow = {};
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -125,25 +134,25 @@ struct VisitorRowOrScalarBroadcast {
|
|||||||
auto coord_v = filter(tC_cRow);
|
auto coord_v = filter(tC_cRow);
|
||||||
auto dst_v = filter(tC_rRow);
|
auto dst_v = filter(tC_rRow);
|
||||||
|
|
||||||
if (params_ptr->ptr_row) {
|
if (params_ptr->row_broadcast) {
|
||||||
// In this case we are loading from a row vector and broadcasting
|
// In this case we are loading from a row vector and broadcasting
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int i = 0; i < size(src_v); ++i) {
|
for (int i = 0; i < size(src_v); ++i) {
|
||||||
bool guard = get<1>(coord_v(i)) < n;
|
bool guard = get<1>(coord_v(i)) < n;
|
||||||
cutlass::arch::global_load<VecType, sizeof(VecType)>(dst_v(i), (void const*)&src_v(i), guard);
|
cutlass::arch::global_load<VecType, sizeof(VecType)>(
|
||||||
|
dst_v(i), (void const*)&src_v(i), guard);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// In this case we are loading from a scalar and broadcasting
|
// In this case we are loading from a scalar and broadcasting
|
||||||
VecType filled_vec;
|
VecType filled_vec;
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int i = 0; i < VecLength; i++) {
|
for (int i = 0; i < VecLength; i++) {
|
||||||
reinterpret_cast<Element*>(&filled_vec)[i] = params_ptr->null_default;
|
reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
|
||||||
}
|
}
|
||||||
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int i = 0; i < size(src_v); ++i) {
|
for (int i = 0; i < size(src_v); ++i) {
|
||||||
if(get<1>(coord_v(i)) < n)
|
if (get<1>(coord_v(i)) < n) {
|
||||||
{
|
|
||||||
dst_v(i) = filled_vec;
|
dst_v(i) = filled_vec;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -208,9 +217,11 @@ template<
|
|||||||
>
|
>
|
||||||
struct VisitorColOrScalarBroadcast {
|
struct VisitorColOrScalarBroadcast {
|
||||||
|
|
||||||
|
// This struct has been modified to have a bool indicating that ptr_col is a
|
||||||
|
// scalar that must be broadcast.
|
||||||
struct Arguments {
|
struct Arguments {
|
||||||
Element const* ptr_col = nullptr;
|
Element const* ptr_col = nullptr;
|
||||||
Element null_default = Element(0);
|
bool col_broadcast = true;
|
||||||
StrideMNL dCol = {};
|
StrideMNL dCol = {};
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -230,11 +241,6 @@ struct VisitorColOrScalarBroadcast {
|
|||||||
|
|
||||||
struct SharedStorage { };
|
struct SharedStorage { };
|
||||||
|
|
||||||
// Global load type
|
|
||||||
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
|
|
||||||
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
|
|
||||||
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
|
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
VisitorColOrScalarBroadcast() { }
|
VisitorColOrScalarBroadcast() { }
|
||||||
|
|
||||||
@@ -267,7 +273,7 @@ struct VisitorColOrScalarBroadcast {
|
|||||||
int m;
|
int m;
|
||||||
|
|
||||||
// This function is modified from VisitorColBroadcast
|
// This function is modified from VisitorColBroadcast
|
||||||
CUTLASS_DEVICE void
|
CUTLASS_DEVICE void
|
||||||
begin_epilogue() {
|
begin_epilogue() {
|
||||||
clear(tC_rCol);
|
clear(tC_rCol);
|
||||||
|
|
||||||
@@ -277,7 +283,7 @@ struct VisitorColOrScalarBroadcast {
|
|||||||
pred(i) = get<0>(tC_cCol(i)) < m;
|
pred(i) = get<0>(tC_cCol(i)) < m;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params_ptr->ptr_col) {
|
if (params_ptr->col_broadcast) {
|
||||||
// In this case we are loading from a column vector and broadcasting
|
// In this case we are loading from a column vector and broadcasting
|
||||||
copy_if(pred, tC_gCol, tC_rCol);
|
copy_if(pred, tC_gCol, tC_rCol);
|
||||||
} else {
|
} else {
|
||||||
@@ -286,8 +292,8 @@ struct VisitorColOrScalarBroadcast {
|
|||||||
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int i = 0; i < size(dst_v); ++i) {
|
for (int i = 0; i < size(dst_v); ++i) {
|
||||||
if(pred(i)){
|
if (pred(i)) {
|
||||||
dst_v(i) = params_ptr->null_default;
|
dst_v(i) = *(params_ptr->ptr_col);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
389
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
Normal file
389
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
Normal file
@@ -0,0 +1,389 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||||
|
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice,
|
||||||
|
*this list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||||
|
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||||
|
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||||
|
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||||
|
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||||
|
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||||
|
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||||
|
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||||
|
*POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// This file is a modified excerpt of
|
||||||
|
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
|
||||||
|
// from https://github.com/NVIDIA/cutlass v3.5.0
|
||||||
|
// It has been modified to support either row/column or scalar broadcasting
|
||||||
|
// where the tensor being loaded from is always passed in via a device pointer.
|
||||||
|
// This lets one compiled kernel handle all cases of per-tensor or
|
||||||
|
// per-channel/per-token quantization.
|
||||||
|
//
|
||||||
|
// This interface also allows the scales to be passed in as tensors that
|
||||||
|
// consistently reside on the device, which avoids an issue with a previous
|
||||||
|
// implementation where scalars needed to be on the CPU since they
|
||||||
|
// were passed in via float values. This created a potential performance hazard
|
||||||
|
// if scales were initially on the device, and caused torch.compile graph
|
||||||
|
// breaks when moving scales to the CPU.
|
||||||
|
//
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
// Turn off clang-format for the entire file to keep it close to upstream
|
||||||
|
// clang-format off
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/arch/barrier.h"
|
||||||
|
|
||||||
|
#include "cute/tensor.hpp"
|
||||||
|
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
|
||||||
|
|
||||||
|
namespace cutlass::epilogue::fusion {
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
using namespace detail;
|
||||||
|
|
||||||
|
// Row vector broadcast
|
||||||
|
template<
|
||||||
|
// Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
|
||||||
|
// ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
|
||||||
|
int Stages,
|
||||||
|
class CtaTileShapeMNK,
|
||||||
|
class Element,
|
||||||
|
class StrideMNL = Stride<_0,_1,_0>,
|
||||||
|
int Alignment = 128 / sizeof_bits_v<Element>
|
||||||
|
>
|
||||||
|
struct Sm90RowOrScalarBroadcast {
|
||||||
|
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
|
||||||
|
static_assert(
|
||||||
|
(cute::is_same_v<StrideMNL, Stride<_0,_1, _0>>) || // row vector broadcast, e.g. per-col alpha/bias
|
||||||
|
(cute::is_same_v<StrideMNL, Stride<_0,_1,int>>)); // batched row vector broadcast
|
||||||
|
|
||||||
|
// Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
|
||||||
|
struct SharedStorage {
|
||||||
|
alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row;
|
||||||
|
};
|
||||||
|
|
||||||
|
// This struct has been modified to have a bool indicating that ptr_row is a
|
||||||
|
// scalar that must be broadcast, instead of containing a scalar that is
|
||||||
|
// valid if ptr_row is null.
|
||||||
|
struct Arguments {
|
||||||
|
Element const* ptr_row = nullptr;
|
||||||
|
bool row_broadcast = true;
|
||||||
|
StrideMNL dRow = {};
|
||||||
|
};
|
||||||
|
|
||||||
|
using Params = Arguments;
|
||||||
|
|
||||||
|
template <class ProblemShape>
|
||||||
|
static constexpr Params
|
||||||
|
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||||
|
return args;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ProblemShape>
|
||||||
|
static size_t
|
||||||
|
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ProblemShape>
|
||||||
|
static cutlass::Status
|
||||||
|
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||||
|
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||||
|
return cutlass::Status::kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Sm90RowOrScalarBroadcast() { }
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||||
|
: params(params),
|
||||||
|
smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { }
|
||||||
|
|
||||||
|
Params params;
|
||||||
|
Element* smem_row;
|
||||||
|
|
||||||
|
CUTLASS_DEVICE bool
|
||||||
|
is_producer_load_needed() const {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE bool
|
||||||
|
is_C_load_needed() const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE bool
|
||||||
|
is_zero() const {
|
||||||
|
return (!params.row_broadcast && *(params.ptr_row) == Element(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int EpiTiles, class GTensor, class STensor>
|
||||||
|
struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)
|
||||||
|
: gRow(cute::forward<GTensor>(gRow)),
|
||||||
|
sRow(cute::forward<STensor>(sRow)),
|
||||||
|
params(params) {}
|
||||||
|
|
||||||
|
GTensor gRow; // (CTA_M,CTA_N)
|
||||||
|
STensor sRow; // (CTA_M,CTA_N,PIPE)
|
||||||
|
Params const& params;
|
||||||
|
|
||||||
|
CUTLASS_DEVICE void
|
||||||
|
begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {
|
||||||
|
if (params.ptr_row == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (issue_tma_load) {
|
||||||
|
// Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
|
||||||
|
constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;
|
||||||
|
cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);
|
||||||
|
// Issue the TMA bulk copy
|
||||||
|
auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);
|
||||||
|
// Filter so we don't issue redundant copies over stride-0 modes
|
||||||
|
int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
|
||||||
|
copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <class... Args>
|
||||||
|
CUTLASS_DEVICE auto
|
||||||
|
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||||
|
|
||||||
|
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||||
|
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||||
|
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
|
||||||
|
Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N)
|
||||||
|
Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE)
|
||||||
|
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
|
||||||
|
make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
|
||||||
|
|
||||||
|
constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
|
||||||
|
return ProducerLoadCallbacks<EpiTiles, decltype(gRow), decltype(sRow)>(
|
||||||
|
cute::move(gRow), cute::move(sRow), params);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int EpiTiles, class RTensor, class STensor>
|
||||||
|
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params)
|
||||||
|
: tCrRow(cute::forward<RTensor>(tCrRow)),
|
||||||
|
tCsRow(cute::forward<STensor>(tCsRow)),
|
||||||
|
params(params) {}
|
||||||
|
|
||||||
|
RTensor tCrRow; // (CPY,CPY_M,CPY_N)
|
||||||
|
STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
|
||||||
|
Params const& params;
|
||||||
|
|
||||||
|
CUTLASS_DEVICE void
|
||||||
|
previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
|
||||||
|
if (!params.row_broadcast) {
|
||||||
|
fill(tCrRow, *(params.ptr_row));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (epi_m == 0) { // Assumes M-major subtile loop
|
||||||
|
// Filter so we don't issue redundant copies over stride-0 modes
|
||||||
|
// (only works if 0-strides are in same location, which is by construction)
|
||||||
|
int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
|
||||||
|
copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ElementAccumulator, int FragmentSize>
|
||||||
|
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||||
|
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||||
|
Array<Element, FragmentSize> frg_row;
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < FragmentSize; ++i) {
|
||||||
|
frg_row[i] = tCrRow(epi_v * FragmentSize + i);
|
||||||
|
}
|
||||||
|
|
||||||
|
return frg_row;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <
|
||||||
|
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||||
|
class... Args
|
||||||
|
>
|
||||||
|
CUTLASS_DEVICE auto
|
||||||
|
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||||
|
|
||||||
|
Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE)
|
||||||
|
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
|
||||||
|
make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
|
||||||
|
Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
|
||||||
|
sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||||
|
Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N)
|
||||||
|
|
||||||
|
constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
|
||||||
|
return ConsumerStoreCallbacks<EpiTiles, decltype(tCrRow), decltype(tCsRow)>(
|
||||||
|
cute::move(tCrRow), cute::move(tCsRow), params);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Column vector broadcast
|
||||||
|
template<
|
||||||
|
int Stages,
|
||||||
|
class CtaTileShapeMNK,
|
||||||
|
class Element,
|
||||||
|
class StrideMNL = Stride<_1,_0,_0>,
|
||||||
|
int Alignment = 128 / sizeof_bits_v<Element>
|
||||||
|
>
|
||||||
|
struct Sm90ColOrScalarBroadcast {
|
||||||
|
static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
|
||||||
|
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
|
||||||
|
static_assert(
|
||||||
|
(cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
|
||||||
|
(cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
|
||||||
|
|
||||||
|
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
|
||||||
|
struct SharedStorage { };
|
||||||
|
|
||||||
|
// This struct has been modified to have a bool indicating that ptr_col is a
|
||||||
|
// scalar that must be broadcast, instead of containing a scalar that is
|
||||||
|
// valid if ptr_col is null.
|
||||||
|
struct Arguments {
|
||||||
|
Element const* ptr_col = nullptr;
|
||||||
|
bool col_broadcast = true;
|
||||||
|
StrideMNL dCol = {};
|
||||||
|
};
|
||||||
|
|
||||||
|
using Params = Arguments;
|
||||||
|
|
||||||
|
template <class ProblemShape>
|
||||||
|
static constexpr Params
|
||||||
|
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||||
|
return args;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ProblemShape>
|
||||||
|
static size_t
|
||||||
|
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ProblemShape>
|
||||||
|
static cutlass::Status
|
||||||
|
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||||
|
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||||
|
return cutlass::Status::kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE bool
|
||||||
|
is_producer_load_needed() const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE bool
|
||||||
|
is_C_load_needed() const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE bool
|
||||||
|
is_zero() const {
|
||||||
|
return (!params.col_broadcast && *(params.ptr_col) == Element(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Sm90ColOrScalarBroadcast() { }
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||||
|
: params(params) { }
|
||||||
|
|
||||||
|
Params params;
|
||||||
|
|
||||||
|
template <class... Args>
|
||||||
|
CUTLASS_DEVICE auto
|
||||||
|
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||||
|
return EmptyProducerLoadCallbacks{};
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class GTensor, class RTensor>
|
||||||
|
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params)
|
||||||
|
: tCgCol(cute::forward<GTensor>(tCgCol)),
|
||||||
|
tCrCol(cute::forward<RTensor>(tCrCol)),
|
||||||
|
params(params) {}
|
||||||
|
|
||||||
|
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
|
RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
|
Params const& params;
|
||||||
|
|
||||||
|
CUTLASS_DEVICE void
|
||||||
|
begin() {
|
||||||
|
if (!params.col_broadcast) {
|
||||||
|
fill(tCrCol, *(params.ptr_col));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter so we don't issue redundant copies over stride-0 modes
|
||||||
|
// (only works if 0-strides are in same location, which is by construction)
|
||||||
|
copy_aligned(filter(tCgCol), filter(tCrCol));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ElementAccumulator, int FragmentSize>
|
||||||
|
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||||
|
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||||
|
Array<Element, FragmentSize> frg_col;
|
||||||
|
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < FragmentSize; ++i) {
|
||||||
|
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
|
||||||
|
}
|
||||||
|
|
||||||
|
return frg_col;
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
template <
|
||||||
|
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||||
|
class... Args
|
||||||
|
>
|
||||||
|
CUTLASS_DEVICE auto
|
||||||
|
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||||
|
|
||||||
|
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||||
|
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
|
||||||
|
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
|
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||||
|
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
|
|
||||||
|
return ConsumerStoreCallbacks<decltype(tCgCol), decltype(tCrCol)>(
|
||||||
|
cute::move(tCgCol), cute::move(tCrCol), params);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
@@ -22,17 +22,64 @@
|
|||||||
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
|
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
|
||||||
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
|
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
|
||||||
|
|
||||||
#include "cutlass_visitor_2x_broadcast_epilogue.hpp"
|
#include "broadcast_load_epilogue_c2x.hpp"
|
||||||
#include "common.hpp"
|
#include "common.hpp"
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
This defines a quantized GEMM operation with dequantized output, similar to
|
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
|
||||||
torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for
|
|
||||||
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
|
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
|
||||||
|
|
||||||
|
Epilogue functions can be defined to post-process the output before it is
|
||||||
|
written to GPU memory.
|
||||||
|
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
|
||||||
|
as well as a static prepare_args function that constructs an
|
||||||
|
EVTCompute::Arguments struct.
|
||||||
|
*/
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Wrappers for the GEMM kernel that are used to guard against compilation on
|
||||||
|
// architectures that will never use the kernel. The purpose of this is to
|
||||||
|
// reduce the size of the compiled binary.
|
||||||
|
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||||
|
// into code that will be executed on the device where it is defined.
|
||||||
|
template <typename Kernel>
|
||||||
|
struct enable_sm75_to_sm80 : Kernel {
|
||||||
|
template <typename... Args>
|
||||||
|
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||||
|
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
|
||||||
|
Kernel::invoke(std::forward<Args>(args)...);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Kernel>
|
||||||
|
struct enable_sm80_to_sm89 : Kernel {
|
||||||
|
template <typename... Args>
|
||||||
|
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||||
|
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
|
||||||
|
Kernel::invoke(std::forward<Args>(args)...);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Kernel>
|
||||||
|
struct enable_sm89_to_sm90 : Kernel {
|
||||||
|
template <typename... Args>
|
||||||
|
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||||
|
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
|
||||||
|
Kernel::invoke(std::forward<Args>(args)...);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
This epilogue function defines a quantized GEMM operation similar to
|
||||||
|
torch._scaled_mm.
|
||||||
|
|
||||||
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
|
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
|
||||||
per-row. B can be quantized per-tensor or per-column.
|
per-row. B can be quantized per-tensor or per-column.
|
||||||
Any combination of per-tensor and per-row or column is supported.
|
Any combination of per-tensor and per-row or column is supported.
|
||||||
@@ -45,30 +92,9 @@ using namespace cute;
|
|||||||
the A and B operands respectively. These scales may be either per-tensor or
|
the A and B operands respectively. These scales may be either per-tensor or
|
||||||
per row or column.
|
per row or column.
|
||||||
*/
|
*/
|
||||||
|
template <typename ElementD, typename OutputTileThreadMap>
|
||||||
namespace {
|
struct ScaledEpilogue {
|
||||||
|
private:
|
||||||
template <typename Arch, typename ElementAB_, typename ElementD_,
|
|
||||||
typename TileShape, typename WarpShape, typename InstructionShape,
|
|
||||||
int32_t MainLoopStages>
|
|
||||||
struct cutlass_2x_gemm {
|
|
||||||
using ElementAB = ElementAB_;
|
|
||||||
using ElementD = ElementD_;
|
|
||||||
|
|
||||||
using ElementAcc =
|
|
||||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
|
||||||
float>::type;
|
|
||||||
|
|
||||||
using Operator =
|
|
||||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
|
|
||||||
cutlass::arch::OpMultiplyAddSaturate,
|
|
||||||
cutlass::arch::OpMultiplyAdd>::type;
|
|
||||||
|
|
||||||
using OutputTileThreadMap =
|
|
||||||
cutlass::epilogue::threadblock::OutputTileThreadLayout<
|
|
||||||
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
|
|
||||||
>;
|
|
||||||
|
|
||||||
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||||
|
|
||||||
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
|
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
|
||||||
@@ -88,20 +114,62 @@ struct cutlass_2x_gemm {
|
|||||||
cutlass::multiplies, ElementD, float,
|
cutlass::multiplies, ElementD, float,
|
||||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
using EVTCompute1 =
|
public:
|
||||||
|
using EVTCompute =
|
||||||
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
|
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
|
||||||
|
using ArgumentType = typename EVTCompute::Arguments;
|
||||||
|
|
||||||
|
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales) {
|
||||||
|
using ScaleAArgs = typename ScaleA::Arguments;
|
||||||
|
using ScaleBArgs = typename ScaleB::Arguments;
|
||||||
|
|
||||||
|
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
||||||
|
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
||||||
|
|
||||||
|
typename EVTCompute0::Arguments evt0_compute_args{b_args};
|
||||||
|
|
||||||
|
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
|
||||||
|
return evt_compute_args;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Arch, template <typename> typename ArchGuard,
|
||||||
|
typename ElementAB_, typename ElementD_,
|
||||||
|
template <typename, typename> typename Epilogue_, typename TileShape,
|
||||||
|
typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
|
||||||
|
struct cutlass_2x_gemm {
|
||||||
|
using ElementAB = ElementAB_;
|
||||||
|
using ElementD = ElementD_;
|
||||||
|
|
||||||
|
using ElementAcc =
|
||||||
|
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||||
|
float>::type;
|
||||||
|
|
||||||
|
using Operator =
|
||||||
|
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
|
||||||
|
cutlass::arch::OpMultiplyAddSaturate,
|
||||||
|
cutlass::arch::OpMultiplyAdd>::type;
|
||||||
|
|
||||||
|
using OutputTileThreadMap =
|
||||||
|
cutlass::epilogue::threadblock::OutputTileThreadLayout<
|
||||||
|
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
|
||||||
|
>;
|
||||||
|
|
||||||
|
using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
|
||||||
|
using EVTCompute = typename Epilogue::EVTCompute;
|
||||||
|
|
||||||
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
|
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
|
||||||
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
|
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
|
||||||
Stride<int64_t, Int<1>, Int<0>>>;
|
Stride<int64_t, Int<1>, Int<0>>>;
|
||||||
|
|
||||||
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute1>;
|
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
using RowMajor = typename cutlass::layout::RowMajor;
|
using RowMajor = typename cutlass::layout::RowMajor;
|
||||||
using ColumnMajor = typename cutlass::layout::ColumnMajor;
|
using ColumnMajor = typename cutlass::layout::ColumnMajor;
|
||||||
using KernelType =
|
using KernelType =
|
||||||
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
|
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
|
||||||
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
|
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
|
||||||
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
|
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
|
||||||
float, cutlass::layout::RowMajor, 4,
|
float, cutlass::layout::RowMajor, 4,
|
||||||
@@ -112,17 +180,16 @@ struct cutlass_2x_gemm {
|
|||||||
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
|
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
|
||||||
MainLoopStages, Operator,
|
MainLoopStages, Operator,
|
||||||
1 /* epilogue stages */
|
1 /* epilogue stages */
|
||||||
>::GemmKernel;
|
>::GemmKernel>;
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
|
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename Gemm>
|
template <typename Gemm, typename... EpilogueArgs>
|
||||||
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
|
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::Tensor const& b,
|
||||||
torch::Tensor const& a_scales,
|
EpilogueArgs&&... epilogue_params) {
|
||||||
torch::Tensor const& b_scales) {
|
|
||||||
using ElementAB = typename Gemm::ElementAB;
|
using ElementAB = typename Gemm::ElementAB;
|
||||||
using ElementD = typename Gemm::ElementD;
|
using ElementD = typename Gemm::ElementD;
|
||||||
|
|
||||||
@@ -142,29 +209,14 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
|
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
|
||||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||||
|
|
||||||
auto a_scales_ptr = a_scales.data_ptr<float>();
|
|
||||||
auto b_scales_ptr = b_scales.data_ptr<float>();
|
|
||||||
|
|
||||||
// If A and B are quantized per-tensor, then these scale tensors are scalars,
|
|
||||||
// and they are passed in via the second argument.
|
|
||||||
using ScaleAArgs = typename Gemm::ScaleA::Arguments;
|
|
||||||
ScaleAArgs a_args = a_scales.numel() == 1
|
|
||||||
? ScaleAArgs{nullptr, a_scales.item<float>(), {}}
|
|
||||||
: ScaleAArgs{a_scales.data_ptr<float>(), {}, {}};
|
|
||||||
|
|
||||||
using ScaleBArgs = typename Gemm::ScaleB::Arguments;
|
|
||||||
ScaleBArgs b_args = b_scales.numel() == 1
|
|
||||||
? ScaleBArgs{nullptr, b_scales.item<float>(), {}}
|
|
||||||
: ScaleBArgs{b_scales.data_ptr<float>(), {}, {}};
|
|
||||||
|
|
||||||
typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args};
|
|
||||||
|
|
||||||
typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args,
|
|
||||||
evt0_compute_args};
|
|
||||||
typename Gemm::D::Arguments d_args{c_ptr, c_stride};
|
typename Gemm::D::Arguments d_args{c_ptr, c_stride};
|
||||||
|
|
||||||
|
using Epilogue = typename Gemm::Epilogue;
|
||||||
|
auto evt_args =
|
||||||
|
Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
|
||||||
|
|
||||||
typename Gemm::EVTD::Arguments epilogue_args{
|
typename Gemm::EVTD::Arguments epilogue_args{
|
||||||
evt1_compute_args,
|
evt_args,
|
||||||
d_args,
|
d_args,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -200,10 +252,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
|
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::Tensor const& b,
|
||||||
torch::Tensor const& a_scales,
|
torch::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales) {
|
torch::Tensor const& b_scales) {
|
||||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||||
@@ -214,23 +266,23 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
|
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
|
||||||
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
if (out.dtype() == torch::kBFloat16) {
|
||||||
return cutlass_scaled_mm_dq_dispatcher<
|
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||||
cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::bfloat16_t,
|
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
|
||||||
TileShape, WarpShape, InstructionShape, 2>>(
|
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
|
||||||
out, a, b, a_scales, b_scales);
|
out, a, b, a_scales, b_scales);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||||
return cutlass_scaled_mm_dq_dispatcher<
|
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||||
cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::half_t, TileShape,
|
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
|
||||||
WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
|
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
|
||||||
b_scales);
|
out, a, b, a_scales, b_scales);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
|
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::Tensor const& b,
|
||||||
torch::Tensor const& a_scales,
|
torch::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales) {
|
torch::Tensor const& b_scales) {
|
||||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||||
@@ -241,23 +293,23 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||||
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
if (out.dtype() == torch::kBFloat16) {
|
||||||
return cutlass_scaled_mm_dq_dispatcher<
|
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||||
cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::bfloat16_t,
|
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
|
||||||
TileShape, WarpShape, InstructionShape, 5>>(
|
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
|
||||||
out, a, b, a_scales, b_scales);
|
out, a, b, a_scales, b_scales);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||||
return cutlass_scaled_mm_dq_dispatcher<
|
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||||
cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::half_t, TileShape,
|
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
|
||||||
WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
|
||||||
b_scales);
|
out, a, b, a_scales, b_scales);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
|
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::Tensor const& b,
|
||||||
torch::Tensor const& a_scales,
|
torch::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales) {
|
torch::Tensor const& b_scales) {
|
||||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||||
@@ -269,15 +321,15 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||||
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
if (out.dtype() == torch::kBFloat16) {
|
||||||
return cutlass_scaled_mm_dq_dispatcher<
|
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||||
cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::bfloat16_t,
|
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
|
||||||
TileShape, WarpShape, InstructionShape, 5>>(
|
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
|
||||||
out, a, b, a_scales, b_scales);
|
out, a, b, a_scales, b_scales);
|
||||||
} else {
|
} else {
|
||||||
assert(out.dtype() == torch::kFloat16);
|
assert(out.dtype() == torch::kFloat16);
|
||||||
return cutlass_scaled_mm_dq_dispatcher<
|
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||||
cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::half_t,
|
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
|
||||||
TileShape, WarpShape, InstructionShape, 5>>(
|
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
|
||||||
out, a, b, a_scales, b_scales);
|
out, a, b, a_scales, b_scales);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -285,16 +337,16 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||||
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
if (out.dtype() == torch::kBFloat16) {
|
||||||
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||||
cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::bfloat16_t,
|
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
|
||||||
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
cutlass::bfloat16_t, ScaledEpilogue, TileShape, WarpShape,
|
||||||
b_scales);
|
InstructionShape, 5>>(out, a, b, a_scales, b_scales);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||||
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||||
cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::half_t,
|
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
|
||||||
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
cutlass::half_t, ScaledEpilogue, TileShape, WarpShape,
|
||||||
b_scales);
|
InstructionShape, 5>>(out, a, b, a_scales, b_scales);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
364
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
Normal file
364
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
// clang-format will break include orders
|
||||||
|
// clang-format off
|
||||||
|
#include <cudaTypedefs.h>
|
||||||
|
|
||||||
|
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||||
|
|
||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
|
||||||
|
#include "cute/tensor.hpp"
|
||||||
|
#include "cute/atom/mma_atom.hpp"
|
||||||
|
#include "cutlass/numeric_types.h"
|
||||||
|
|
||||||
|
#include "cutlass/util/device_memory.h"
|
||||||
|
|
||||||
|
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||||
|
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||||
|
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||||
|
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||||
|
|
||||||
|
#include "broadcast_load_epilogue_c3x.hpp"
|
||||||
|
#include "common.hpp"
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
/*
|
||||||
|
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||||
|
NVIDIA GPUs with sm90a (Hopper) or later.
|
||||||
|
|
||||||
|
Epilogue functions can be defined to post-process the output before it is
|
||||||
|
written to GPU memory.
|
||||||
|
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
|
||||||
|
as well as a static prepare_args function that constructs an
|
||||||
|
EVTCompute::Arguments struct.
|
||||||
|
*/
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Rounds `num` up to the nearest power of two.
//
// Returns `num` unchanged for 0 and 1, and `num` itself when it is already a
// power of two. Inputs above 2^31 have no power-of-two successor that fits in
// uint32_t; they saturate to 2^31 (the largest representable power of two)
// instead of performing an out-of-range shift.
uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;

  constexpr uint32_t kBits = 32;  // width of uint32_t in bits

  // Number of significant bits in (num - 1). num - 1 >= 1 on this path, so
  // __builtin_clz is never called with 0 (for which its result is undefined).
  uint32_t const shift =
      kBits - static_cast<uint32_t>(__builtin_clz(num - 1));

  // shift == 32 would mean the answer is 2^32: saturate rather than shift by
  // the full width of the type.
  if (shift >= kBits) return uint32_t{1} << (kBits - 1);

  // Shift an unsigned 1 so that reaching bit 31 cannot overflow a signed int.
  return uint32_t{1} << shift;
}
|
||||||
|
|
||||||
|
// A wrapper for the GEMM kernel that is used to guard against compilation on
|
||||||
|
// architectures that will never use the kernel. The purpose of this is to
|
||||||
|
// reduce the size of the compiled binary.
|
||||||
|
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||||
|
// into code that will be executed on the device where it is defined.
|
||||||
|
template <typename Kernel>
|
||||||
|
struct enable_sm90_or_later : Kernel {
|
||||||
|
template <typename... Args>
|
||||||
|
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||||
|
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
|
||||||
|
Kernel::operator()(std::forward<Args>(args)...);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
This epilogue function defines a quantized GEMM operation similar to
|
||||||
|
torch._scaled_mm.
|
||||||
|
|
||||||
|
A and B may be both either int8 or fp8_e4m3. A can be
|
||||||
|
quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
|
||||||
|
Any combination of per-tensor and per-row or column is supported.
|
||||||
|
A and B must have symmetric quantization (zero point == 0).
|
||||||
|
|
||||||
|
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
|
||||||
|
scales are applied elementwise with numpy-style broadcasting.
|
||||||
|
|
||||||
|
ScaleA and ScaleB define the epilogue functions that apply the scales for
|
||||||
|
the A and B operands respectively. These scales may be either per-tensor or
|
||||||
|
per row or column.
|
||||||
|
*/
|
||||||
|
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||||
|
struct ScaledEpilogue {
|
||||||
|
private:
|
||||||
|
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||||
|
|
||||||
|
using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
|
||||||
|
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
|
||||||
|
Stride<Int<1>, Int<0>, Int<0>>>;
|
||||||
|
|
||||||
|
using ScaleBDescriptor =
|
||||||
|
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
|
||||||
|
EpilogueDescriptor, float>;
|
||||||
|
|
||||||
|
using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
|
||||||
|
ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
|
||||||
|
typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||||
|
|
||||||
|
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||||
|
cutlass::multiplies, float, float,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
using EVTCompute0 =
|
||||||
|
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||||
|
|
||||||
|
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||||
|
cutlass::multiplies, ElementD, float,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
using EVTCompute =
|
||||||
|
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
||||||
|
using ArgumentType = typename EVTCompute::Arguments;
|
||||||
|
|
||||||
|
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales) {
|
||||||
|
using ScaleA_Args = typename ScaleA::Arguments;
|
||||||
|
using ScaleB_Args = typename ScaleB::Arguments;
|
||||||
|
|
||||||
|
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
||||||
|
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
||||||
|
|
||||||
|
return ArgumentType{a_args, {b_args}};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename ElementAB_, typename ElementD_,
|
||||||
|
template <typename, typename, typename> typename Epilogue_,
|
||||||
|
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||||
|
typename EpilogueSchedule>
|
||||||
|
struct cutlass_3x_gemm {
|
||||||
|
using ElementAB = ElementAB_;
|
||||||
|
using ElementD = ElementD_;
|
||||||
|
using ElementAcc =
|
||||||
|
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||||
|
float>::type;
|
||||||
|
|
||||||
|
using EpilogueDescriptor =
|
||||||
|
cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||||
|
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
|
||||||
|
ElementD, EpilogueSchedule>;
|
||||||
|
|
||||||
|
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||||
|
|
||||||
|
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
|
||||||
|
using ElementC = void;
|
||||||
|
using StrideC = StrideD;
|
||||||
|
|
||||||
|
using EVTCompute = typename Epilogue::EVTCompute;
|
||||||
|
|
||||||
|
using CollectiveEpilogue =
|
||||||
|
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
|
||||||
|
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||||
|
ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
|
||||||
|
EpilogueSchedule, EVTCompute>::CollectiveOp;
|
||||||
|
|
||||||
|
static constexpr size_t CEStorageSize =
|
||||||
|
sizeof(typename CollectiveEpilogue::SharedStorage);
|
||||||
|
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
|
||||||
|
static_cast<int>(CEStorageSize)>;
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
using CollectiveMainloop =
|
||||||
|
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||||
|
ElementAB, cutlass::layout::RowMajor, 16,
|
||||||
|
ElementAB, cutlass::layout::ColumnMajor, 16,
|
||||||
|
ElementAcc, TileShape, ClusterShape,
|
||||||
|
Stages,
|
||||||
|
KernelSchedule>::CollectiveOp;
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
||||||
|
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
||||||
|
cutlass::gemm::PersistentScheduler>>;
|
||||||
|
|
||||||
|
struct GemmKernel : public KernelType {};
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Gemm, typename... EpilogueArgs>
|
||||||
|
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
EpilogueArgs&&... epilogue_params) {
|
||||||
|
using ElementAB = typename Gemm::ElementAB;
|
||||||
|
using ElementD = typename Gemm::ElementD;
|
||||||
|
|
||||||
|
int32_t m = a.size(0);
|
||||||
|
int32_t n = b.size(1);
|
||||||
|
int32_t k = a.size(1);
|
||||||
|
|
||||||
|
int64_t lda = a.stride(0);
|
||||||
|
int64_t ldb = b.stride(1);
|
||||||
|
int64_t ldc = out.stride(0);
|
||||||
|
|
||||||
|
using StrideA = Stride<int64_t, Int<1>, Int<0>>;
|
||||||
|
using StrideB = Stride<int64_t, Int<1>, Int<0>>;
|
||||||
|
using StrideC = typename Gemm::StrideC;
|
||||||
|
|
||||||
|
StrideA a_stride{lda, Int<1>{}, Int<0>{}};
|
||||||
|
StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
|
||||||
|
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||||
|
|
||||||
|
using GemmKernel = typename Gemm::GemmKernel;
|
||||||
|
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
|
||||||
|
|
||||||
|
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||||
|
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||||
|
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
|
||||||
|
b_stride};
|
||||||
|
|
||||||
|
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||||
|
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||||
|
Gemm::Epilogue::prepare_args(
|
||||||
|
std::forward<EpilogueArgs>(epilogue_params)...),
|
||||||
|
c_ptr, c_stride, c_ptr, c_stride};
|
||||||
|
|
||||||
|
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||||
|
prob_shape, mainloop_args, epilogue_args};
|
||||||
|
|
||||||
|
// Launch the CUTLASS GEMM kernel.
|
||||||
|
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||||
|
GemmOp gemm_op;
|
||||||
|
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||||
|
|
||||||
|
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||||
|
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||||
|
|
||||||
|
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||||
|
|
||||||
|
cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
|
||||||
|
CUTLASS_CHECK(status);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename InType, typename OutType,
|
||||||
|
template <typename, typename, typename> typename Epilogue, int32_t M>
|
||||||
|
struct sm90_fp8_config {
|
||||||
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
|
using KernelSchedule =
|
||||||
|
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||||
|
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||||
|
using TileShape = Shape<_128, _128, _128>;
|
||||||
|
using ClusterShape = Shape<_2, _1, _1>;
|
||||||
|
|
||||||
|
using Cutlass3xGemm =
|
||||||
|
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||||
|
KernelSchedule, EpilogueSchedule>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename InType, typename OutType,
|
||||||
|
template <typename, typename, typename> typename Epilogue>
|
||||||
|
struct sm90_fp8_config<InType, OutType, Epilogue, 128> {
|
||||||
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
|
using KernelSchedule =
|
||||||
|
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||||
|
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||||
|
using TileShape = Shape<_64, _128, _128>;
|
||||||
|
using ClusterShape = Shape<_2, _1, _1>;
|
||||||
|
|
||||||
|
using Cutlass3xGemm =
|
||||||
|
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||||
|
KernelSchedule, EpilogueSchedule>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename InType, typename OutType,
|
||||||
|
template <typename, typename, typename> typename Epilogue>
|
||||||
|
struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
|
||||||
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
|
using KernelSchedule =
|
||||||
|
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||||
|
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||||
|
using TileShape = Shape<_64, _64, _128>;
|
||||||
|
using ClusterShape = Shape<_1, _8, _1>;
|
||||||
|
|
||||||
|
using Cutlass3xGemm =
|
||||||
|
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||||
|
KernelSchedule, EpilogueSchedule>;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
template <typename InType, typename OutType,
|
||||||
|
template <typename, typename, typename> typename Epilogue,
|
||||||
|
typename... EpilogueArgs>
|
||||||
|
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
EpilogueArgs&&... args) {
|
||||||
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
|
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||||
|
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||||
|
|
||||||
|
using Cutlass3xGemmDefault =
|
||||||
|
typename sm90_fp8_config<InType, OutType, Epilogue, 0>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemmM64 =
|
||||||
|
typename sm90_fp8_config<InType, OutType, Epilogue, 64>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemmM128 =
|
||||||
|
typename sm90_fp8_config<InType, OutType, Epilogue, 128>::Cutlass3xGemm;
|
||||||
|
|
||||||
|
uint32_t const m = a.size(0);
|
||||||
|
uint32_t const mp2 =
|
||||||
|
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
|
||||||
|
|
||||||
|
if (mp2 <= 64) {
|
||||||
|
// m in [1, 64]
|
||||||
|
return cutlass_gemm_caller<Cutlass3xGemmM64>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||||
|
} else if (mp2 <= 128) {
|
||||||
|
// m in (64, 128]
|
||||||
|
return cutlass_gemm_caller<Cutlass3xGemmM128>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||||
|
} else {
|
||||||
|
// m in (128, inf)
|
||||||
|
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SM90 entry point for the scaled GEMM: dispatches on the dtype of `a`
// (int8 or fp8 e4m3) and on the dtype of `out` (bfloat16 or float16).
//
// Both scale tensors must be float32. For int8 inputs a single fixed kernel
// configuration is used; for fp8 inputs the choice is delegated to
// cutlass_gemm_sm90_fp8_dispatch. Any unsupported dtype combination fails a
// TORCH_CHECK.
void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

    // Fixed configuration shared by both int8 output types.
    using TileShape = Shape<_128, _128, _128>;
    using ClusterShape = Shape<_1, _2, _1>;
    using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
    using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;

    using Int8ToBf16Gemm =
        cutlass_3x_gemm<int8_t, cutlass::bfloat16_t, ScaledEpilogue, TileShape,
                        ClusterShape, KernelSchedule, EpilogueSchedule>;
    using Int8ToFp16Gemm =
        cutlass_3x_gemm<int8_t, cutlass::half_t, ScaledEpilogue, TileShape,
                        ClusterShape, KernelSchedule, EpilogueSchedule>;

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_caller<Int8ToBf16Gemm>(out, a, b, a_scales,
                                                 b_scales);
    }

    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return cutlass_gemm_caller<Int8ToFp16Gemm>(out, a, b, a_scales, b_scales);
  }

  // Anything that is not int8 must be fp8 e4m3 on both operands.
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  if (out.dtype() == torch::kBFloat16) {
    return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                          cutlass::bfloat16_t, ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }

  TORCH_CHECK(out.dtype() == torch::kFloat16);
  return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t, cutlass::half_t,
                                        ScaledEpilogue>(out, a, b, a_scales,
                                                        b_scales);
}
|
||||||
|
|
||||||
|
#endif
|
||||||
@@ -1,249 +0,0 @@
|
|||||||
// clang-format will break include orders
|
|
||||||
// clang-format off
|
|
||||||
#include <cudaTypedefs.h>
|
|
||||||
|
|
||||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
|
||||||
|
|
||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include <sstream>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "cutlass/cutlass.h"
|
|
||||||
|
|
||||||
#include "cute/tensor.hpp"
|
|
||||||
#include "cute/atom/mma_atom.hpp"
|
|
||||||
#include "cutlass/numeric_types.h"
|
|
||||||
|
|
||||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
|
||||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
|
||||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
|
||||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
|
||||||
|
|
||||||
#include "common.hpp"
|
|
||||||
// clang-format on
|
|
||||||
|
|
||||||
using namespace cute;
|
|
||||||
|
|
||||||
/*
|
|
||||||
This defines a quantized GEMM operation with dequantized output, similar to
|
|
||||||
torch._scaled_mm. It is defined using the CUTLASS 3.x API, and is used for
|
|
||||||
NVIDIA GPUs with sm90a (Hopper) or later.
|
|
||||||
|
|
||||||
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
|
|
||||||
per-row. B can be quantized per-tensor or per-column.
|
|
||||||
Any combination of per-tensor and per-row or column is supported.
|
|
||||||
A and B must have symmetric quantization (zero point == 0).
|
|
||||||
|
|
||||||
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
|
|
||||||
scales are applied elementwise with numpy-style broadcasting.
|
|
||||||
|
|
||||||
ScaleA and ScaleB define the epilogue functions that apply the scales for
|
|
||||||
the A and B operands respectively. These scales may be either per-tensor or
|
|
||||||
per row or column.
|
|
||||||
*/
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
template <typename ElementAB_, typename ElementD_, typename TileShape,
|
|
||||||
typename ClusterShape, typename KernelSchedule,
|
|
||||||
typename EpilogueSchedule>
|
|
||||||
struct cutlass_3x_gemm {
|
|
||||||
using ElementAB = ElementAB_;
|
|
||||||
using ElementD = ElementD_;
|
|
||||||
using ElementAcc =
|
|
||||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
|
||||||
float>::type;
|
|
||||||
|
|
||||||
using EpilogueDescriptor =
|
|
||||||
cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
|
||||||
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
|
|
||||||
ElementD, EpilogueSchedule>;
|
|
||||||
|
|
||||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
|
||||||
|
|
||||||
using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
|
||||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
|
|
||||||
Stride<Int<1>, Int<0>, Int<0>>>;
|
|
||||||
|
|
||||||
using ScaleBDescriptor =
|
|
||||||
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
|
|
||||||
EpilogueDescriptor, float>;
|
|
||||||
|
|
||||||
using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
|
||||||
ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
|
|
||||||
typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
|
|
||||||
|
|
||||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
|
||||||
cutlass::multiplies, float, float,
|
|
||||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
|
||||||
|
|
||||||
using EVTCompute0 =
|
|
||||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
|
||||||
|
|
||||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
|
||||||
cutlass::multiplies, ElementD, float,
|
|
||||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
|
||||||
|
|
||||||
using EVTCompute1 =
|
|
||||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
|
||||||
|
|
||||||
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
|
|
||||||
using ElementC = void;
|
|
||||||
using StrideC = StrideD;
|
|
||||||
|
|
||||||
using CollectiveEpilogue =
|
|
||||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
|
||||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
|
|
||||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
|
||||||
ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
|
|
||||||
EpilogueSchedule, EVTCompute1>::CollectiveOp;
|
|
||||||
|
|
||||||
static constexpr size_t CEStorageSize =
|
|
||||||
sizeof(typename CollectiveEpilogue::SharedStorage);
|
|
||||||
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
|
|
||||||
static_cast<int>(CEStorageSize)>;
|
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
using CollectiveMainloop =
|
|
||||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
|
||||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
|
||||||
ElementAB, cutlass::layout::RowMajor, 16,
|
|
||||||
ElementAB, cutlass::layout::ColumnMajor, 16,
|
|
||||||
ElementAcc, TileShape, ClusterShape,
|
|
||||||
Stages,
|
|
||||||
KernelSchedule>::CollectiveOp;
|
|
||||||
// clang-format on
|
|
||||||
|
|
||||||
using KernelType = cutlass::gemm::kernel::GemmUniversal<
|
|
||||||
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
|
||||||
cutlass::gemm::PersistentScheduler>;
|
|
||||||
|
|
||||||
struct GemmKernel : public KernelType {};
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename Gemm>
|
|
||||||
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales) {
|
|
||||||
using ElementAB = typename Gemm::ElementAB;
|
|
||||||
using ElementD = typename Gemm::ElementD;
|
|
||||||
|
|
||||||
int32_t m = a.size(0);
|
|
||||||
int32_t n = b.size(1);
|
|
||||||
int32_t k = a.size(1);
|
|
||||||
|
|
||||||
int64_t lda = a.stride(0);
|
|
||||||
int64_t ldb = b.stride(1);
|
|
||||||
int64_t ldc = out.stride(0);
|
|
||||||
|
|
||||||
using StrideA = Stride<int64_t, Int<1>, Int<0>>;
|
|
||||||
using StrideB = Stride<int64_t, Int<1>, Int<0>>;
|
|
||||||
using StrideC = typename Gemm::StrideC;
|
|
||||||
|
|
||||||
StrideA a_stride{lda, Int<1>{}, Int<0>{}};
|
|
||||||
StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
|
|
||||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
|
||||||
|
|
||||||
using GemmKernel = typename Gemm::GemmKernel;
|
|
||||||
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
|
|
||||||
|
|
||||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
|
||||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
|
||||||
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
|
|
||||||
b_stride};
|
|
||||||
|
|
||||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
|
||||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
|
||||||
{}, c_ptr, c_stride, c_ptr, c_stride};
|
|
||||||
|
|
||||||
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
|
|
||||||
prob_shape, mainloop_args, epilogue_args};
|
|
||||||
|
|
||||||
using ScaleA_Args = typename Gemm::ScaleA::Arguments;
|
|
||||||
using ScaleB_Args = typename Gemm::ScaleB::Arguments;
|
|
||||||
ScaleA_Args a_args = a_scales.numel() == 1
|
|
||||||
? ScaleA_Args{nullptr, a_scales.item<float>(), {}}
|
|
||||||
: ScaleA_Args{a_scales.data_ptr<float>(), {}, {}};
|
|
||||||
|
|
||||||
ScaleB_Args b_args = b_scales.numel() == 1
|
|
||||||
? ScaleB_Args{nullptr, b_scales.item<float>(), {}}
|
|
||||||
: ScaleB_Args{b_scales.data_ptr<float>(), {}, {}};
|
|
||||||
|
|
||||||
args.epilogue.thread = {a_args, {b_args}};
|
|
||||||
|
|
||||||
// Launch the CUTLASS GEMM kernel.
|
|
||||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
|
||||||
GemmOp gemm_op;
|
|
||||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
|
||||||
|
|
||||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
|
||||||
TORCH_CHECK(workspace_size == 0);
|
|
||||||
|
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
|
||||||
cutlass::Status status = gemm_op.run(args, stream);
|
|
||||||
CUTLASS_CHECK(status);
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// SM90 entry point for the scaled GEMM: dispatches on the dtype of `a`
// (int8 or fp8 e4m3) and on the dtype of `out` (bfloat16 or float16), then
// hands off to cutlass_scaled_mm_dq_dispatcher.
//
// Both scale tensors must be float32. Any unsupported dtype combination
// fails a TORCH_CHECK.
void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
                               torch::Tensor const& b,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

    // Kernel configuration for int8 operands.
    using TileShape = Shape<_128, _128, _128>;
    using ClusterShape = Shape<_1, _2, _1>;
    using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
    using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;

    using Int8ToBf16Gemm =
        cutlass_3x_gemm<int8_t, cutlass::bfloat16_t, TileShape, ClusterShape,
                        KernelSchedule, EpilogueSchedule>;
    using Int8ToFp16Gemm =
        cutlass_3x_gemm<int8_t, cutlass::half_t, TileShape, ClusterShape,
                        KernelSchedule, EpilogueSchedule>;

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_scaled_mm_dq_dispatcher<Int8ToBf16Gemm>(out, a, b,
                                                             a_scales, b_scales);
    }

    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return cutlass_scaled_mm_dq_dispatcher<Int8ToFp16Gemm>(out, a, b, a_scales,
                                                           b_scales);
  }

  // Anything that is not int8 must be fp8 e4m3 on both operands.
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  // Kernel configuration for fp8 operands.
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_1, _2, _1>;
  using KernelSchedule =
      cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;

  using Fp8ToBf16Gemm =
      cutlass_3x_gemm<cutlass::float_e4m3_t, cutlass::bfloat16_t, TileShape,
                      ClusterShape, KernelSchedule, EpilogueSchedule>;
  using Fp8ToFp16Gemm =
      cutlass_3x_gemm<cutlass::float_e4m3_t, cutlass::half_t, TileShape,
                      ClusterShape, KernelSchedule, EpilogueSchedule>;

  if (out.dtype() == torch::kBFloat16) {
    return cutlass_scaled_mm_dq_dispatcher<Fp8ToBf16Gemm>(out, a, b, a_scales,
                                                          b_scales);
  }

  TORCH_CHECK(out.dtype() == torch::kFloat16);
  return cutlass_scaled_mm_dq_dispatcher<Fp8ToFp16Gemm>(out, a, b, a_scales,
                                                        b_scales);
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
@@ -1,75 +0,0 @@
|
|||||||
#include <cudaTypedefs.h>
|
|
||||||
|
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales);
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_dq_sm80(torch::Tensor& c, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales);
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales);
|
|
||||||
|
|
||||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
|
||||||
void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Validates the operands of a scaled GEMM and forwards them to the kernel
// family matching the compute capability of the device that holds `a`.
//
// c        : 2-D output, row-major, c.size() == (a.size(0), b.size(1)).
// a        : 2-D left operand, row-major (a.stride(1) == 1).
// b        : 2-D right operand, column-major (b.stride(0) == 1).
// a_scales : contiguous, 1 element or a.size(0) elements.
// b_scales : contiguous, 1 element or b.size(1) elements.
//
// Fails a TORCH_CHECK when a precondition is violated, when the device
// attributes cannot be queried, or when the compute capability is below 7.5.
void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
                          torch::Tensor const& b, torch::Tensor const& a_scales,
                          torch::Tensor const& b_scales) {
  // Query the device the operands live on rather than a hard-coded device 0,
  // so the right kernel family is chosen on systems with mixed GPUs.
  int const device_index = static_cast<int>(a.get_device());
  int32_t major_capability = 0;
  int32_t minor_capability = 0;
  cudaError_t status = cudaDeviceGetAttribute(
      &major_capability, cudaDevAttrComputeCapabilityMajor, device_index);
  TORCH_CHECK(status == cudaSuccess,
              "cudaDeviceGetAttribute(ComputeCapabilityMajor) failed: ",
              cudaGetErrorString(status));
  status = cudaDeviceGetAttribute(
      &minor_capability, cudaDevAttrComputeCapabilityMinor, device_index);
  TORCH_CHECK(status == cudaSuccess,
              "cudaDeviceGetAttribute(ComputeCapabilityMinor) failed: ",
              cudaGetErrorString(status));
  int32_t version_num = major_capability * 10 + minor_capability;

  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

  if (version_num >= 90) {
    // Hopper

    // Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
    cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales);
#else
    cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
#endif
  } else if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales);
  } else if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
  } else {
    // Turing
    TORCH_CHECK(version_num >= 75);
    cutlass_scaled_mm_dq_sm75(c, a, b, a_scales, b_scales);
  }
}
|
|
||||||
75
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
Normal file
75
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
#include <cudaTypedefs.h>
|
||||||
|
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales);
|
||||||
|
|
||||||
|
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||||
|
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Validates the operands of a scaled GEMM and forwards them to the kernel
// family matching the compute capability of the device that holds `a`.
//
// c        : 2-D output, row-major, c.size() == (a.size(0), b.size(1)).
// a        : 2-D left operand, row-major (a.stride(1) == 1).
// b        : 2-D right operand, column-major (b.stride(0) == 1).
// a_scales : contiguous, 1 element or a.size(0) elements.
// b_scales : contiguous, 1 element or b.size(1) elements.
//
// Fails a TORCH_CHECK when a precondition is violated, when the device
// attributes cannot be queried, or when the compute capability is below 7.5.
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales) {
  // Query the device the operands live on rather than a hard-coded device 0,
  // so the right kernel family is chosen on systems with mixed GPUs.
  int const device_index = static_cast<int>(a.get_device());
  int32_t major_capability = 0;
  int32_t minor_capability = 0;
  cudaError_t status = cudaDeviceGetAttribute(
      &major_capability, cudaDevAttrComputeCapabilityMajor, device_index);
  TORCH_CHECK(status == cudaSuccess,
              "cudaDeviceGetAttribute(ComputeCapabilityMajor) failed: ",
              cudaGetErrorString(status));
  status = cudaDeviceGetAttribute(
      &minor_capability, cudaDevAttrComputeCapabilityMinor, device_index);
  TORCH_CHECK(status == cudaSuccess,
              "cudaDeviceGetAttribute(ComputeCapabilityMinor) failed: ",
              cudaGetErrorString(status));
  int32_t version_num = major_capability * 10 + minor_capability;

  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

  if (version_num >= 90) {
    // Hopper

    // Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales);
#else
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
#endif
  } else if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales);
  } else if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
  } else {
    // Turing
    TORCH_CHECK(version_num >= 75);
    cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales);
  }
}
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
@@ -23,8 +23,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
|||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
|
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
|
||||||
const scalar_t val, const float scale) {
|
const scalar_t val, const float inverted_scale) {
|
||||||
float x = static_cast<float>(val) / scale;
|
float x = static_cast<float>(val) * inverted_scale;
|
||||||
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
|
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
|
||||||
return static_cast<c10::Float8_e4m3fn>(r);
|
return static_cast<c10::Float8_e4m3fn>(r);
|
||||||
}
|
}
|
||||||
@@ -71,15 +71,56 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
struct __align__(8) vec4_t {
|
||||||
|
scalar_t x;
|
||||||
|
scalar_t y;
|
||||||
|
scalar_t z;
|
||||||
|
scalar_t w;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef struct __align__(4) {
|
||||||
|
c10::Float8_e4m3fn x;
|
||||||
|
c10::Float8_e4m3fn y;
|
||||||
|
c10::Float8_e4m3fn z;
|
||||||
|
c10::Float8_e4m3fn w;
|
||||||
|
}
|
||||||
|
float8x4_t;
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
|
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
|
||||||
const scalar_t* __restrict__ input,
|
const scalar_t* __restrict__ input,
|
||||||
const float* __restrict__ scale,
|
const float* __restrict__ scale,
|
||||||
int64_t num_elems) {
|
int64_t num_elems) {
|
||||||
int i = blockDim.x * blockIdx.x + threadIdx.x;
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
while (i < num_elems) {
|
|
||||||
out[i] = scaled_fp8_conversion(input[i], *scale);
|
// Invert the scale so that we can use multiplications to avoid expensive
|
||||||
i += blockDim.x * gridDim.x;
|
// division.
|
||||||
|
const float inverted_scale = 1.0f / (*scale);
|
||||||
|
|
||||||
|
// Vectorized input/output to better utilize memory bandwidth.
|
||||||
|
const vec4_t<scalar_t>* vectorized_in =
|
||||||
|
reinterpret_cast<const vec4_t<scalar_t>*>(input);
|
||||||
|
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
|
||||||
|
|
||||||
|
int num_vec_elems = num_elems >> 2;
|
||||||
|
|
||||||
|
#pragma unroll 4
|
||||||
|
for (int i = tid; i < num_vec_elems; i += blockDim.x * gridDim.x) {
|
||||||
|
vec4_t<scalar_t> in_vec = vectorized_in[i];
|
||||||
|
float8x4_t out_vec;
|
||||||
|
|
||||||
|
out_vec.x = scaled_fp8_conversion(in_vec.x, inverted_scale);
|
||||||
|
out_vec.y = scaled_fp8_conversion(in_vec.y, inverted_scale);
|
||||||
|
out_vec.z = scaled_fp8_conversion(in_vec.z, inverted_scale);
|
||||||
|
out_vec.w = scaled_fp8_conversion(in_vec.w, inverted_scale);
|
||||||
|
vectorized_out[i] = out_vec;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle the remaining elements if num_elems is not divisible by 4
|
||||||
|
for (int i = num_vec_elems * 4 + tid; i < num_elems;
|
||||||
|
i += blockDim.x * gridDim.x) {
|
||||||
|
out[i] = scaled_fp8_conversion(input[i], inverted_scale);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa
|
|||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
@@ -1823,7 +1823,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height,
|
|||||||
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
||||||
torch::Tensor b_gptq_qzeros,
|
torch::Tensor b_gptq_qzeros,
|
||||||
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
|
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
|
||||||
bool use_exllama, int bit) {
|
bool use_exllama, int64_t bit) {
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
||||||
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
|
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
|
||||||
at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
|
at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
|
||||||
@@ -1845,7 +1845,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
|||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
|
|
||||||
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit) {
|
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) {
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
|
||||||
vllm::gptq::shuffle_exllama_weight(
|
vllm::gptq::shuffle_exllama_weight(
|
||||||
(uint32_t*)q_weight.data_ptr(),
|
(uint32_t*)q_weight.data_ptr(),
|
||||||
|
|||||||
@@ -1867,4 +1867,4 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
|||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
#include <torch/all.h>
|
#include <torch/all.h>
|
||||||
#include <torch/python.h>
|
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
|
|||||||
@@ -21,29 +21,47 @@
|
|||||||
#include "cuda_compat.h"
|
#include "cuda_compat.h"
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
template <typename T, int numLanes = WARP_SIZE>
|
|
||||||
__inline__ __device__ T warpReduceSum(T val) {
|
namespace detail {
|
||||||
static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
|
|
||||||
"numLanes is not a positive power of 2!");
|
template <typename T>
|
||||||
static_assert(numLanes <= WARP_SIZE);
|
__inline__ __device__ T _max(T a, T b) {
|
||||||
#pragma unroll
|
return max(a, b);
|
||||||
for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
|
|
||||||
val += VLLM_SHFL_XOR_SYNC(val, mask);
|
|
||||||
return val;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__inline__ __device__ T _sum(T a, T b) {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using ReduceFnType = T (*)(T, T);
|
||||||
|
|
||||||
// Helper function to return the next largest power of 2
|
// Helper function to return the next largest power of 2
|
||||||
static constexpr int _nextPow2(unsigned int num) {
|
static constexpr int _nextPow2(unsigned int num) {
|
||||||
if (num <= 1) return num;
|
if (num <= 1) return num;
|
||||||
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Calculate the sum of all elements in a block */
|
template <typename T, int numLanes = WARP_SIZE>
|
||||||
|
__inline__ __device__ T warpReduce(T val, ReduceFnType<T> fn) {
|
||||||
|
static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
|
||||||
|
"numLanes is not a positive power of 2!");
|
||||||
|
static_assert(numLanes <= WARP_SIZE);
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
|
||||||
|
val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask));
|
||||||
|
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, int maxBlockSize = 1024>
|
template <typename T, int maxBlockSize = 1024>
|
||||||
__inline__ __device__ T blockReduceSum(T val) {
|
__inline__ __device__ T blockReduce(T val, ReduceFnType<T> fn) {
|
||||||
static_assert(maxBlockSize <= 1024);
|
static_assert(maxBlockSize <= 1024);
|
||||||
if constexpr (maxBlockSize > WARP_SIZE) {
|
if constexpr (maxBlockSize > WARP_SIZE) {
|
||||||
val = warpReduceSum<T>(val);
|
val = warpReduce<T>(val, fn);
|
||||||
// Calculates max number of lanes that need to participate in the last
|
// Calculates max number of lanes that need to participate in the last
|
||||||
// warpReduce
|
// warpReduce
|
||||||
constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
|
constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
|
||||||
@@ -56,12 +74,22 @@ __inline__ __device__ T blockReduceSum(T val) {
|
|||||||
|
|
||||||
val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
|
val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
|
||||||
: (T)(0.0f);
|
: (T)(0.0f);
|
||||||
val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
|
val = warpReduce<T, _nextPow2(maxActiveLanes)>(val, fn);
|
||||||
} else {
|
} else {
|
||||||
// A single warpReduce is equal to blockReduce
|
// A single warpReduce is equal to blockReduce
|
||||||
val = warpReduceSum<T, _nextPow2(maxBlockSize)>(val);
|
val = warpReduce<T, _nextPow2(maxBlockSize)>(val, fn);
|
||||||
}
|
}
|
||||||
return val;
|
return val;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, int maxBlockSize = 1024>
|
||||||
|
__inline__ __device__ T blockReduceMax(T val) {
|
||||||
|
return blockReduce<T, maxBlockSize>(val, detail::_max<T>);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int maxBlockSize = 1024>
|
||||||
|
__inline__ __device__ T blockReduceSum(T val) {
|
||||||
|
return blockReduce<T, maxBlockSize>(val, detail::_sum<T>);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|||||||
22
csrc/registration.h
Normal file
22
csrc/registration.h
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <Python.h>
|
||||||
|
|
||||||
|
#define _CONCAT(A, B) A##B
|
||||||
|
#define CONCAT(A, B) _CONCAT(A, B)
|
||||||
|
|
||||||
|
#define _STRINGIFY(A) #A
|
||||||
|
#define STRINGIFY(A) _STRINGIFY(A)
|
||||||
|
|
||||||
|
// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
|
||||||
|
// could be a macro instead of a literal token.
|
||||||
|
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
|
||||||
|
|
||||||
|
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
|
||||||
|
// via python's import statement.
|
||||||
|
#define REGISTER_EXTENSION(NAME) \
|
||||||
|
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
|
||||||
|
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
|
||||||
|
STRINGIFY(NAME), nullptr, 0, nullptr}; \
|
||||||
|
return PyModule_Create(&module); \
|
||||||
|
}
|
||||||
283
csrc/torch_bindings.cpp
Normal file
283
csrc/torch_bindings.cpp
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
#include "cache.h"
|
||||||
|
#include "cuda_utils.h"
|
||||||
|
#include "ops.h"
|
||||||
|
#include "registration.h"
|
||||||
|
|
||||||
|
#include <torch/library.h>
|
||||||
|
|
||||||
|
// Note on op signatures:
|
||||||
|
// The X_meta signatures are for the meta functions corresponding to op X.
|
||||||
|
// They must be kept in sync with the signature for X. Generally, only
|
||||||
|
// functions that return Tensors require a meta function.
|
||||||
|
//
|
||||||
|
// See the following links for detailed docs on op registration and function
|
||||||
|
// schemas.
|
||||||
|
// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
|
||||||
|
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
|
||||||
|
|
||||||
|
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||||
|
// vLLM custom ops
|
||||||
|
|
||||||
|
// Attention ops
|
||||||
|
// Compute the attention between an input query and the cached
|
||||||
|
// keys/values using PagedAttention.
|
||||||
|
ops.def(
|
||||||
|
"paged_attention_v1("
|
||||||
|
" Tensor! out, Tensor query, Tensor key_cache,"
|
||||||
|
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||||
|
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||||
|
" int max_seq_len, Tensor? alibi_slopes,"
|
||||||
|
" str kv_cache_dtype, float kv_scale, int tp_rank,"
|
||||||
|
" int blocksparse_local_blocks,"
|
||||||
|
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||||
|
" int blocksparse_head_sliding_step) -> ()");
|
||||||
|
ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
|
||||||
|
|
||||||
|
// PagedAttention V2.
|
||||||
|
ops.def(
|
||||||
|
"paged_attention_v2("
|
||||||
|
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
|
||||||
|
" Tensor tmp_out, Tensor query, Tensor key_cache,"
|
||||||
|
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||||
|
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||||
|
" int max_seq_len, Tensor? alibi_slopes,"
|
||||||
|
" str kv_cache_dtype, float kv_scale, int tp_rank,"
|
||||||
|
" int blocksparse_local_blocks,"
|
||||||
|
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||||
|
" int blocksparse_head_sliding_step) -> ()");
|
||||||
|
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
|
||||||
|
|
||||||
|
// Activation ops
|
||||||
|
// Activation function used in SwiGLU.
|
||||||
|
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
|
||||||
|
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
|
||||||
|
|
||||||
|
// Activation function used in GeGLU with `none` approximation.
|
||||||
|
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
|
||||||
|
ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
|
||||||
|
|
||||||
|
// Activation function used in GeGLU with `tanh` approximation.
|
||||||
|
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
|
||||||
|
ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
|
||||||
|
|
||||||
|
// GELU implementation used in GPT-2.
|
||||||
|
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
|
||||||
|
ops.impl("gelu_new", torch::kCUDA, &gelu_new);
|
||||||
|
|
||||||
|
// Approximate GELU implementation.
|
||||||
|
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
|
||||||
|
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
|
||||||
|
|
||||||
|
// Layernorm
|
||||||
|
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
||||||
|
ops.def(
|
||||||
|
"rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
|
||||||
|
"()");
|
||||||
|
ops.impl("rms_norm", torch::kCUDA, &rms_norm);
|
||||||
|
|
||||||
|
// In-place fused Add and RMS Normalization.
|
||||||
|
ops.def(
|
||||||
|
"fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
|
||||||
|
"float epsilon) -> ()");
|
||||||
|
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
|
||||||
|
|
||||||
|
// Rotary embedding
|
||||||
|
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
|
||||||
|
ops.def(
|
||||||
|
"rotary_embedding(Tensor positions, Tensor! query,"
|
||||||
|
" Tensor! key, int head_size,"
|
||||||
|
" Tensor cos_sin_cache, bool is_neox) -> ()");
|
||||||
|
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
|
||||||
|
|
||||||
|
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
|
||||||
|
// (supports multiple loras).
|
||||||
|
ops.def(
|
||||||
|
"batched_rotary_embedding(Tensor positions, Tensor! query,"
|
||||||
|
" Tensor! key, int head_size,"
|
||||||
|
" Tensor cos_sin_cache, bool is_neox,"
|
||||||
|
" int rot_dim,"
|
||||||
|
" Tensor cos_sin_cache_offsets) -> ()");
|
||||||
|
ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
|
||||||
|
|
||||||
|
// Quantization ops
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
// Quantized GEMM for AQLM.
|
||||||
|
ops.def("aqlm_gemm", &aqlm_gemm);
|
||||||
|
ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
|
||||||
|
|
||||||
|
// Decompression method for AQLM.
|
||||||
|
ops.def("aqlm_dequant", &aqlm_dequant);
|
||||||
|
ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
|
||||||
|
|
||||||
|
// Quantized GEMM for AWQ.
|
||||||
|
ops.def("awq_gemm", &awq_gemm);
|
||||||
|
ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
|
||||||
|
|
||||||
|
// Dequantization for AWQ.
|
||||||
|
ops.def("awq_dequantize", &awq_dequantize);
|
||||||
|
ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
|
||||||
|
|
||||||
|
// Marlin (Dense) Optimized Quantized GEMM for GPTQ.
|
||||||
|
ops.def("marlin_gemm", &marlin_gemm);
|
||||||
|
ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
|
||||||
|
|
||||||
|
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
|
||||||
|
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
|
||||||
|
ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
|
||||||
|
|
||||||
|
// gptq_marlin Optimized Quantized GEMM for GPTQ.
|
||||||
|
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
|
||||||
|
ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
|
||||||
|
|
||||||
|
// gptq_marlin repack from GPTQ.
|
||||||
|
ops.def("gptq_marlin_repack", &gptq_marlin_repack);
|
||||||
|
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
|
||||||
|
|
||||||
|
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||||
|
// quantization.
|
||||||
|
ops.def(
|
||||||
|
"cutlass_scaled_mm(Tensor! out, Tensor a,"
|
||||||
|
" Tensor b, Tensor a_scales,"
|
||||||
|
" Tensor b_scales) -> ()");
|
||||||
|
ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Quantized GEMM for GPTQ.
|
||||||
|
ops.def("gptq_gemm", &gptq_gemm);
|
||||||
|
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
|
||||||
|
|
||||||
|
// Post processing for GPTQ.
|
||||||
|
ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
|
||||||
|
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
|
||||||
|
|
||||||
|
// Quantized GEMM for SqueezeLLM.
|
||||||
|
ops.def(
|
||||||
|
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
|
||||||
|
"lookup_table) -> ()");
|
||||||
|
ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm);
|
||||||
|
|
||||||
|
// Compute FP8 quantized tensor for given scaling factor.
|
||||||
|
ops.def(
|
||||||
|
"static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
|
||||||
|
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
|
||||||
|
|
||||||
|
// Compute FP8 quantized tensor and scaling factor.
|
||||||
|
ops.def(
|
||||||
|
"dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
|
||||||
|
"()");
|
||||||
|
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
|
||||||
|
|
||||||
|
// Aligning the number of tokens to be processed by each expert such
|
||||||
|
// that it is divisible by the block size.
|
||||||
|
ops.def(
|
||||||
|
"moe_align_block_size(Tensor topk_ids, int num_experts,"
|
||||||
|
" int block_size, Tensor! sorted_token_ids,"
|
||||||
|
" Tensor! experts_ids,"
|
||||||
|
" Tensor! num_tokens_post_pad) -> ()");
|
||||||
|
ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||||
|
|
||||||
|
// Compute int8 quantized tensor for given scaling factor.
|
||||||
|
ops.def(
|
||||||
|
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
|
||||||
|
"()");
|
||||||
|
ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);
|
||||||
|
|
||||||
|
// Compute int8 quantized tensor and scaling factor
|
||||||
|
ops.def(
|
||||||
|
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
|
||||||
|
"()");
|
||||||
|
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
|
||||||
|
&dynamic_scaled_int8_quant);
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||||
|
// Cache ops
|
||||||
|
// Swap in (out) the cache blocks from src to dst.
|
||||||
|
cache_ops.def(
|
||||||
|
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
|
||||||
|
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
|
||||||
|
|
||||||
|
// Copy the cache blocks from src to dst.
|
||||||
|
cache_ops.def(
|
||||||
|
"copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
|
||||||
|
"block_mapping) -> ()");
|
||||||
|
cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
|
||||||
|
|
||||||
|
// Reshape the key and value tensors and cache them.
|
||||||
|
cache_ops.def(
|
||||||
|
"reshape_and_cache(Tensor key, Tensor value,"
|
||||||
|
" Tensor! key_cache, Tensor! value_cache,"
|
||||||
|
" Tensor slot_mapping,"
|
||||||
|
" str kv_cache_dtype,"
|
||||||
|
" float kv_scale) -> ()");
|
||||||
|
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
|
||||||
|
|
||||||
|
// Reshape the key and value tensors and cache them.
|
||||||
|
cache_ops.def(
|
||||||
|
"reshape_and_cache_flash(Tensor key, Tensor value,"
|
||||||
|
" Tensor! key_cache,"
|
||||||
|
" Tensor! value_cache,"
|
||||||
|
" Tensor slot_mapping,"
|
||||||
|
" str kv_cache_dtype) -> ()");
|
||||||
|
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
|
||||||
|
&reshape_and_cache_flash);
|
||||||
|
|
||||||
|
// Convert the key and value cache to fp8 data type.
|
||||||
|
cache_ops.def(
|
||||||
|
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
|
||||||
|
"kv_cache_dtype) -> ()");
|
||||||
|
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
|
||||||
|
// Cuda utils
|
||||||
|
|
||||||
|
// Gets the specified device attribute.
|
||||||
|
cuda_utils.def("get_device_attribute", &get_device_attribute);
|
||||||
|
cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute);
|
||||||
|
|
||||||
|
// Gets the maximum shared memory per block device attribute.
|
||||||
|
cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
|
||||||
|
&get_max_shared_memory_per_block_device_attribute);
|
||||||
|
cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
|
||||||
|
torch::kCUDA,
|
||||||
|
&get_max_shared_memory_per_block_device_attribute);
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
|
||||||
|
// Custom all-reduce kernels
|
||||||
|
custom_ar.def("init_custom_ar", &init_custom_ar);
|
||||||
|
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
|
||||||
|
|
||||||
|
custom_ar.def("should_custom_ar", &should_custom_ar);
|
||||||
|
custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);
|
||||||
|
|
||||||
|
custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
|
||||||
|
custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
|
||||||
|
|
||||||
|
custom_ar.def(
|
||||||
|
"all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
|
||||||
|
"()");
|
||||||
|
custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
|
||||||
|
|
||||||
|
custom_ar.def("dispose", &dispose);
|
||||||
|
custom_ar.impl("dispose", torch::kCPU, &dispose);
|
||||||
|
|
||||||
|
custom_ar.def("meta_size", &meta_size);
|
||||||
|
custom_ar.impl("meta_size", torch::kCPU, &meta_size);
|
||||||
|
|
||||||
|
custom_ar.def("register_buffer", &register_buffer);
|
||||||
|
custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);
|
||||||
|
|
||||||
|
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
|
||||||
|
custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU,
|
||||||
|
&get_graph_buffer_ipc_meta);
|
||||||
|
|
||||||
|
custom_ar.def("register_graph_buffers", &register_graph_buffers);
|
||||||
|
custom_ar.impl("register_graph_buffers", torch::kCPU,
|
||||||
|
&register_graph_buffers);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||||
110
docs/source/automatic_prefix_caching/apc.rst
Normal file
110
docs/source/automatic_prefix_caching/apc.rst
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
.. _apc:
|
||||||
|
|
||||||
|
Introduction
|
||||||
|
============
|
||||||
|
|
||||||
|
What is Automatic Prefix Caching
|
||||||
|
--------------------------------
|
||||||
|
|
||||||
|
Automatic Prefix Caching (APC for short) caches the KV cache of existing queries, so that a new query can directly reuse the KV cache if it shares the same prefix with one of the existing queries, allowing the new query to skip the computation of the shared part.
|
||||||
|
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Technical details on how vLLM implements APC are on the next page.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Enabling APC in vLLM
|
||||||
|
--------------------
|
||||||
|
|
||||||
|
Set ``enable_prefix_caching=True`` in vLLM engine to enable APC. Here is an example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import time
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
|
||||||
|
# A prompt containing a large markdown table. The table is randomly generated by GPT-4.
|
||||||
|
LONG_PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as follows.\n# Table\n" + """
|
||||||
|
| ID | Name | Age | Occupation | Country | Email | Phone Number | Address |
|
||||||
|
|-----|---------------|-----|---------------|---------------|------------------------|----------------|------------------------------|
|
||||||
|
| 1 | John Doe | 29 | Engineer | USA | john.doe@example.com | 555-1234 | 123 Elm St, Springfield, IL |
|
||||||
|
| 2 | Jane Smith | 34 | Doctor | Canada | jane.smith@example.com | 555-5678 | 456 Oak St, Toronto, ON |
|
||||||
|
| 3 | Alice Johnson | 27 | Teacher | UK | alice.j@example.com | 555-8765 | 789 Pine St, London, UK |
|
||||||
|
| 4 | Bob Brown | 45 | Artist | Australia | bob.b@example.com | 555-4321 | 321 Maple St, Sydney, NSW |
|
||||||
|
| 5 | Carol White | 31 | Scientist | New Zealand | carol.w@example.com | 555-6789 | 654 Birch St, Wellington, NZ |
|
||||||
|
| 6 | Dave Green | 28 | Lawyer | Ireland | dave.g@example.com | 555-3456 | 987 Cedar St, Dublin, IE |
|
||||||
|
| 7 | Emma Black | 40 | Musician | USA | emma.b@example.com | 555-1111 | 246 Ash St, New York, NY |
|
||||||
|
| 8 | Frank Blue | 37 | Chef | Canada | frank.b@example.com | 555-2222 | 135 Spruce St, Vancouver, BC |
|
||||||
|
| 9 | Grace Yellow | 50 | Engineer | UK | grace.y@example.com | 555-3333 | 864 Fir St, Manchester, UK |
|
||||||
|
| 10 | Henry Violet | 32 | Artist | Australia | henry.v@example.com | 555-4444 | 753 Willow St, Melbourne, VIC|
|
||||||
|
| 11 | Irene Orange | 26 | Scientist | New Zealand | irene.o@example.com | 555-5555 | 912 Poplar St, Auckland, NZ |
|
||||||
|
| 12 | Jack Indigo | 38 | Teacher | Ireland | jack.i@example.com | 555-6666 | 159 Elm St, Cork, IE |
|
||||||
|
| 13 | Karen Red | 41 | Lawyer | USA | karen.r@example.com | 555-7777 | 357 Cedar St, Boston, MA |
|
||||||
|
| 14 | Leo Brown | 30 | Chef | Canada | leo.b@example.com | 555-8888 | 246 Oak St, Calgary, AB |
|
||||||
|
| 15 | Mia Green | 33 | Musician | UK | mia.g@example.com | 555-9999 | 975 Pine St, Edinburgh, UK |
|
||||||
|
| 16 | Noah Yellow | 29 | Doctor | Australia | noah.y@example.com | 555-0000 | 864 Birch St, Brisbane, QLD |
|
||||||
|
| 17 | Olivia Blue | 35 | Engineer | New Zealand | olivia.b@example.com | 555-1212 | 753 Maple St, Hamilton, NZ |
|
||||||
|
| 18 | Peter Black | 42 | Artist | Ireland | peter.b@example.com | 555-3434 | 912 Fir St, Limerick, IE |
|
||||||
|
| 19 | Quinn White | 28 | Scientist | USA | quinn.w@example.com | 555-5656 | 159 Willow St, Seattle, WA |
|
||||||
|
| 20 | Rachel Red | 31 | Teacher | Canada | rachel.r@example.com | 555-7878 | 357 Poplar St, Ottawa, ON |
|
||||||
|
| 21 | Steve Green | 44 | Lawyer | UK | steve.g@example.com | 555-9090 | 753 Elm St, Birmingham, UK |
|
||||||
|
| 22 | Tina Blue | 36 | Musician | Australia | tina.b@example.com | 555-1213 | 864 Cedar St, Perth, WA |
|
||||||
|
| 23 | Umar Black | 39 | Chef | New Zealand | umar.b@example.com | 555-3435 | 975 Spruce St, Christchurch, NZ|
|
||||||
|
| 24 | Victor Yellow | 43 | Engineer | Ireland | victor.y@example.com | 555-5657 | 246 Willow St, Galway, IE |
|
||||||
|
| 25 | Wendy Orange | 27 | Artist | USA | wendy.o@example.com | 555-7879 | 135 Elm St, Denver, CO |
|
||||||
|
| 26 | Xavier Green | 34 | Scientist | Canada | xavier.g@example.com | 555-9091 | 357 Oak St, Montreal, QC |
|
||||||
|
| 27 | Yara Red | 41 | Teacher | UK | yara.r@example.com | 555-1214 | 975 Pine St, Leeds, UK |
|
||||||
|
| 28 | Zack Blue | 30 | Lawyer | Australia | zack.b@example.com | 555-3436 | 135 Birch St, Adelaide, SA |
|
||||||
|
| 29 | Amy White | 33 | Musician | New Zealand | amy.w@example.com | 555-5658 | 159 Maple St, Wellington, NZ |
|
||||||
|
| 30 | Ben Black | 38 | Chef | Ireland | ben.b@example.com | 555-7870 | 246 Fir St, Waterford, IE |
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_generation_time(llm, sampling_params, prompts):
|
||||||
|
# time the generation
|
||||||
|
start_time = time.time()
|
||||||
|
output = llm.generate(prompts, sampling_params=sampling_params)
|
||||||
|
end_time = time.time()
|
||||||
|
# print the output and generation time
|
||||||
|
print(f"Output: {output[0].outputs[0].text}")
|
||||||
|
print(f"Generation time: {end_time - start_time} seconds.")
|
||||||
|
|
||||||
|
|
||||||
|
# set enable_prefix_caching=True to enable APC
|
||||||
|
llm = LLM(
|
||||||
|
model='lmsys/longchat-13b-16k',
|
||||||
|
enable_prefix_caching=True
|
||||||
|
)
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(temperature=0, max_tokens=100)
|
||||||
|
|
||||||
|
# Querying the age of John Doe
|
||||||
|
get_generation_time(
|
||||||
|
llm,
|
||||||
|
sampling_params,
|
||||||
|
LONG_PROMPT + "Question: what is the age of John Doe? Your answer: The age of John Doe is ",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Querying the age of Zack Blue
|
||||||
|
# This query will be faster since vllm avoids computing the KV cache of LONG_PROMPT again.
|
||||||
|
get_generation_time(
|
||||||
|
llm,
|
||||||
|
sampling_params,
|
||||||
|
LONG_PROMPT + "Question: what is the age of Zack Blue? Your answer: The age of Zack Blue is ",
|
||||||
|
)
|
||||||
|
|
||||||
|
Example workloads
|
||||||
|
-----------------
|
||||||
|
|
||||||
|
We describe two example workloads where APC can provide a huge performance benefit:
|
||||||
|
|
||||||
|
- Long document query, where the user repeatedly queries the same long document (e.g. software manual or annual report) with different queries. In this case, instead of processing the long document again and again, APC allows vLLM to process this long document *only once*, and all future requests can avoid recomputing this long document by reusing its KV cache. This allows vLLM to serve future requests with much higher throughput and much lower latency.
|
||||||
|
- Multi-round conversation, where the user may chat with the application multiple times in the same chatting session. In this case, instead of processing the whole chatting history again and again, APC allows vLLM to reuse the processing results of the chat history across all future rounds of conversation, allowing vLLM to serve future requests with much higher throughput and much lower latency.
|
||||||
|
|
||||||
|
|
||||||
|
Limits
|
||||||
|
------
|
||||||
|
APC in general does not reduce the performance of vLLM. That being said, APC only reduces the time of processing the queries (the prefilling phase) and does not reduce the time of generating new tokens (the decoding phase). So APC does not bring a performance gain when vLLM spends most of the time generating answers to the queries (e.g. when the length of the answer is long), or when new queries do not share the same prefix with any of the existing queries (so that the computation cannot be reused).
|
||||||
43
docs/source/automatic_prefix_caching/details.md
Normal file
43
docs/source/automatic_prefix_caching/details.md
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
# Implementation
|
||||||
|
|
||||||
|
The core idea of PagedAttention is to partition the KV cache of each request into KV Blocks. Each block contains the attention keys and values for a fixed number of tokens. The PagedAttention algorithm allows these blocks to be stored in non-contiguous physical memory so that we can eliminate memory fragmentation by allocating the memory on demand.
|
||||||
|
|
||||||
|
To automatically cache the KV cache, we utilize the following key observation: Each KV block can be uniquely identified by the tokens within the block and the tokens in the prefix before the block.
|
||||||
|
|
||||||
|
```
|
||||||
|
Block 1 Block 2 Block 3
|
||||||
|
[A gentle breeze stirred] [the leaves as children] [laughed in the distance]
|
||||||
|
Block 1: |<--- block tokens ---->|
|
||||||
|
Block 2: |<------- prefix ------>| |<--- block tokens --->|
|
||||||
|
Block 3: |<------------------ prefix -------------------->| |<--- block tokens ---->|
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
In the example above, the KV cache in the first block can be uniquely identified with the tokens “A gentle breeze stirred”. The third block can be uniquely identified with the tokens in the block “laughed in the distance”, along with the prefix tokens “A gentle breeze stirred the leaves as children”. Therefore, we can build the following one-to-one mapping:
|
||||||
|
|
||||||
|
```
|
||||||
|
hash(prefix tokens + block tokens) <--> KV Block
|
||||||
|
```
|
||||||
|
|
||||||
|
With this mapping, we can add another indirection in vLLM’s KV cache management. Previously, each sequence in vLLM maintained a mapping from their logical KV blocks to physical blocks. To achieve automatic caching of KV blocks, we map the logical KV blocks to their hash value and maintain a global hash table of all the physical blocks. In this way, all the KV blocks sharing the same hash value (e.g., shared prefix blocks across two requests) can be mapped to the same physical block and share the memory space.
|
||||||
|
|
||||||
|
|
||||||
|
This design achieves automatic prefix caching without the need to maintain a tree structure among the KV blocks. More specifically, all of the blocks are independent of each other and can be allocated and freed individually, which enables us to manage the KV cache like an ordinary cache in an operating system.
|
||||||
|
|
||||||
|
|
||||||
|
# Generalized Caching Policy
|
||||||
|
|
||||||
|
Keeping all the KV blocks in a hash table enables vLLM to cache KV blocks from earlier requests to save memory and accelerate the computation of future requests. For example, if a new request shares the system prompt with the previous request, the KV cache of the shared prompt can directly be used for the new request without recomputation. However, the total KV cache space is limited and we have to decide which KV blocks to keep or evict when the cache is full.
|
||||||
|
|
||||||
|
Managing KV cache with a hash table allows us to implement flexible caching policies. As an example, in current vLLM, we implement the following eviction policy:
|
||||||
|
|
||||||
|
* When there are no free blocks left, we will evict a KV block with reference count (i.e., number of current requests using the block) equal to 0.
|
||||||
|
* If there are multiple blocks with reference count equal to 0, we prioritize evicting the least recently used block (LRU).
|
||||||
|
* If there are multiple blocks whose last access times are the same, we prioritize the eviction of the block that is at the end of the longest prefix (i.e., has the maximum number of blocks before it).
|
||||||
|
|
||||||
|
Note that this eviction policy effectively implements the same policy as [RadixAttention](https://lmsys.org/blog/2024-01-17-sglang/) when applied to models with full attention, which prioritizes evicting leaf nodes in the prefix tree that have a reference count of zero and are least recently used.
|
||||||
|
|
||||||
|
However, the hash-based KV cache management gives us the flexibility to handle more complicated serving scenarios and implement more complicated eviction policies beyond the policy above:
|
||||||
|
|
||||||
|
- Multi-LoRA serving. When serving requests for multiple LoRA adapters, we can simply let the hash of each KV block also include the LoRA ID the request is querying for, to enable caching for all adapters. In this way, we can jointly manage the KV blocks for different adapters, which simplifies the system implementation and improves the global cache hit rate and efficiency.
|
||||||
|
- Multi-modal models. When the user input includes more than just discrete tokens, we can use different hashing methods to handle the caching of inputs of different modalities. For example, perceptual hashing for images to cache similar input images.
|
||||||
@@ -5,6 +5,7 @@ vLLM Meetups
|
|||||||
|
|
||||||
We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:
|
We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:
|
||||||
|
|
||||||
|
- `The fourth vLLM meetup <https://lu.ma/agivllm>`__, with Cloudflare and BentoML, June 11th 2024. `[Slides] <https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing>`__
|
||||||
- `The third vLLM meetup <https://robloxandvllmmeetup2024.splashthat.com/>`__, with Roblox, April 2nd 2024. `[Slides] <https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing>`__
|
- `The third vLLM meetup <https://robloxandvllmmeetup2024.splashthat.com/>`__, with Roblox, April 2nd 2024. `[Slides] <https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing>`__
|
||||||
- `The second vLLM meetup <https://lu.ma/ygxbpzhl>`__, with IBM Research, January 31st 2024. `[Slides] <https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing>`__ `[Video (vLLM Update)] <https://youtu.be/Y0C-DUvEnZQ>`__ `[Video (IBM Research & torch.compile)] <https://youtu.be/m0dMtFLI-dg>`__
|
- `The second vLLM meetup <https://lu.ma/ygxbpzhl>`__, with IBM Research, January 31st 2024. `[Slides] <https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing>`__ `[Video (vLLM Update)] <https://youtu.be/Y0C-DUvEnZQ>`__ `[Video (IBM Research & torch.compile)] <https://youtu.be/m0dMtFLI-dg>`__
|
||||||
- `The first vLLM meetup <https://lu.ma/first-vllm-meetup>`__, with a16z, October 5th 2023. `[Slides] <https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing>`__
|
- `The first vLLM meetup <https://lu.ma/first-vllm-meetup>`__, with a16z, October 5th 2023. `[Slides] <https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing>`__
|
||||||
|
|||||||
@@ -18,8 +18,9 @@ vLLM is a community project. Our compute resources for development and testing a
|
|||||||
- Replicate
|
- Replicate
|
||||||
- Roblox
|
- Roblox
|
||||||
- RunPod
|
- RunPod
|
||||||
|
- Sequoia Capital
|
||||||
- Trainy
|
- Trainy
|
||||||
- UC Berkeley
|
- UC Berkeley
|
||||||
- UC San Diego
|
- UC San Diego
|
||||||
|
|
||||||
We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM.
|
We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM.
|
||||||
|
|||||||
@@ -90,7 +90,9 @@ autodoc_mock_imports = [
|
|||||||
"sentencepiece",
|
"sentencepiece",
|
||||||
"vllm.cuda_utils",
|
"vllm.cuda_utils",
|
||||||
"vllm._C",
|
"vllm._C",
|
||||||
|
"PIL",
|
||||||
"numpy",
|
"numpy",
|
||||||
|
'triton',
|
||||||
"tqdm",
|
"tqdm",
|
||||||
"tensorizer",
|
"tensorizer",
|
||||||
]
|
]
|
||||||
@@ -116,12 +118,13 @@ class MockedClassDocumenter(autodoc.ClassDocumenter):
|
|||||||
autodoc.ClassDocumenter = MockedClassDocumenter
|
autodoc.ClassDocumenter = MockedClassDocumenter
|
||||||
|
|
||||||
intersphinx_mapping = {
|
intersphinx_mapping = {
|
||||||
'python': ('https://docs.python.org/3', None),
|
"python": ("https://docs.python.org/3", None),
|
||||||
'typing_extensions':
|
"typing_extensions":
|
||||||
('https://typing-extensions.readthedocs.io/en/latest', None),
|
("https://typing-extensions.readthedocs.io/en/latest", None),
|
||||||
'numpy': ('https://numpy.org/doc/stable', None),
|
"pillow": ("https://pillow.readthedocs.io/en/stable", None),
|
||||||
'torch': ('https://pytorch.org/docs/stable', None),
|
"numpy": ("https://numpy.org/doc/stable", None),
|
||||||
'psutil': ('https://psutil.readthedocs.io/en/stable', None),
|
"torch": ("https://pytorch.org/docs/stable", None),
|
||||||
|
"psutil": ("https://psutil.readthedocs.io/en/stable", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
autodoc_preserve_defaults = True
|
autodoc_preserve_defaults = True
|
||||||
|
|||||||
51
docs/source/dev/multimodal/multimodal_index.rst
Normal file
51
docs/source/dev/multimodal/multimodal_index.rst
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
Multi-Modality
|
||||||
|
==============
|
||||||
|
|
||||||
|
.. currentmodule:: vllm.multimodal
|
||||||
|
|
||||||
|
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
|
||||||
|
|
||||||
|
:class:`vllm.inputs.PromptStrictInputs` accepts an additional attribute ``multi_modal_data``
|
||||||
|
which allows you to pass in multi-modal input alongside text and token prompts.
|
||||||
|
|
||||||
|
By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model,
|
||||||
|
you must decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_dummy_data <MultiModalRegistry.register_dummy_data>`,
|
||||||
|
as well as :meth:`MULTIMODAL_REGISTRY.register_input <MultiModalRegistry.register_input>` for each modality type to support.
|
||||||
|
|
||||||
|
.. contents::
|
||||||
|
:local:
|
||||||
|
:backlinks: none
|
||||||
|
|
||||||
|
Module Contents
|
||||||
|
+++++++++++++++
|
||||||
|
|
||||||
|
.. automodule:: vllm.multimodal
|
||||||
|
|
||||||
|
Registry
|
||||||
|
--------
|
||||||
|
|
||||||
|
.. data:: vllm.multimodal.MULTIMODAL_REGISTRY
|
||||||
|
|
||||||
|
The global :class:`MultiModalRegistry` which is used by model runners.
|
||||||
|
|
||||||
|
.. autoclass:: vllm.multimodal.MultiModalRegistry
|
||||||
|
:members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
|
Base Classes
|
||||||
|
------------
|
||||||
|
|
||||||
|
.. autoclass:: vllm.multimodal.MultiModalData
|
||||||
|
:members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
|
.. autoclass:: vllm.multimodal.MultiModalPlugin
|
||||||
|
:members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
|
Image Classes
|
||||||
|
-------------
|
||||||
|
|
||||||
|
.. automodule:: vllm.multimodal.image
|
||||||
|
:members:
|
||||||
|
:show-inheritance:
|
||||||
@@ -10,6 +10,7 @@ Table of contents:
|
|||||||
#. :ref:`Requirements <cpu_backend_requirements>`
|
#. :ref:`Requirements <cpu_backend_requirements>`
|
||||||
#. :ref:`Quick start using Dockerfile <cpu_backend_quick_start_dockerfile>`
|
#. :ref:`Quick start using Dockerfile <cpu_backend_quick_start_dockerfile>`
|
||||||
#. :ref:`Build from source <build_cpu_backend_from_source>`
|
#. :ref:`Build from source <build_cpu_backend_from_source>`
|
||||||
|
#. :ref:`Intel Extension for PyTorch <ipex_guidance>`
|
||||||
#. :ref:`Performance tips <cpu_backend_performance_tips>`
|
#. :ref:`Performance tips <cpu_backend_performance_tips>`
|
||||||
|
|
||||||
.. _cpu_backend_requirements:
|
.. _cpu_backend_requirements:
|
||||||
@@ -18,7 +19,7 @@ Requirements
|
|||||||
------------
|
------------
|
||||||
|
|
||||||
* OS: Linux
|
* OS: Linux
|
||||||
* Compiler: gcc/g++>=12.3.0 (recommended)
|
* Compiler: gcc/g++>=12.3.0 (optional, recommended)
|
||||||
* Instruction set architecture (ISA) requirement: AVX512 is required.
|
* Instruction set architecture (ISA) requirement: AVX512 is required.
|
||||||
|
|
||||||
.. _cpu_backend_quick_start_dockerfile:
|
.. _cpu_backend_quick_start_dockerfile:
|
||||||
@@ -41,7 +42,7 @@ Quick start using Dockerfile
|
|||||||
Build from source
|
Build from source
|
||||||
-----------------
|
-----------------
|
||||||
|
|
||||||
- First, install required compiler. We recommend to use ``gcc/g++ >= 12.3.0`` as the default compiler to avoid potential problems. For example, on Ubuntu 22.4, you can run:
|
- First, install the recommended compiler. We recommend using ``gcc/g++ >= 12.3.0`` as the default compiler to avoid potential problems. For example, on Ubuntu 22.04, you can run:
|
||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
@@ -54,7 +55,7 @@ Build from source
|
|||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
$ pip install --upgrade pip
|
$ pip install --upgrade pip
|
||||||
$ pip install wheel packaging ninja setuptools>=49.4.0 numpy
|
$ pip install wheel packaging ninja "setuptools>=49.4.0" numpy
|
||||||
$ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
|
$ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
|
|
||||||
- Finally, build and install vLLM CPU backend:
|
- Finally, build and install vLLM CPU backend:
|
||||||
@@ -70,6 +71,15 @@ Build from source
|
|||||||
|
|
||||||
- If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable VLLM_CPU_AVX512BF16=1 before the building.
|
- If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable VLLM_CPU_AVX512BF16=1 before the building.
|
||||||
|
|
||||||
|
.. _ipex_guidance:
|
||||||
|
|
||||||
|
Intel Extension for PyTorch
|
||||||
|
---------------------------
|
||||||
|
|
||||||
|
- `Intel Extension for PyTorch (IPEX) <https://github.com/intel/intel-extension-for-pytorch>`_ extends PyTorch with up-to-date feature optimizations for an extra performance boost on Intel hardware.
|
||||||
|
|
||||||
|
- IPEX versions after ``2.3.0`` can be enabled in the CPU backend by default if IPEX is installed.
|
||||||
|
|
||||||
.. _cpu_backend_performance_tips:
|
.. _cpu_backend_performance_tips:
|
||||||
|
|
||||||
Performance tips
|
Performance tips
|
||||||
@@ -77,6 +87,15 @@ Performance tips
|
|||||||
|
|
||||||
- vLLM CPU backend uses environment variable ``VLLM_CPU_KVCACHE_SPACE`` to specify the KV Cache size (e.g, ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users.
|
- vLLM CPU backend uses environment variable ``VLLM_CPU_KVCACHE_SPACE`` to specify the KV Cache size (e.g, ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users.
|
||||||
|
|
||||||
|
- We highly recommend using TCMalloc for high-performance memory allocation and better cache locality. For example, on Ubuntu 22.04, you can run:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ sudo apt-get install libtcmalloc-minimal4 # install TCMalloc library
|
||||||
|
$ find / -name *libtcmalloc* # find the dynamic link library path
|
||||||
|
$ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD # prepend the library to LD_PRELOAD
|
||||||
|
$ python examples/offline_inference.py # run vLLM
|
||||||
|
|
||||||
- vLLM CPU backend uses OpenMP for thread-parallel computation. If you want the best performance on CPU, it will be very critical to isolate CPU cores for OpenMP threads with other thread pools (like web-service event-loop), to avoid CPU oversubscription.
|
- vLLM CPU backend uses OpenMP for thread-parallel computation. If you want the best performance on CPU, it will be very critical to isolate CPU cores for OpenMP threads with other thread pools (like web-service event-loop), to avoid CPU oversubscription.
|
||||||
|
|
||||||
- If using vLLM CPU backend on a bare-metal machine, it is recommended to disable the hyper-threading.
|
- If using vLLM CPU backend on a bare-metal machine, it is recommended to disable the hyper-threading.
|
||||||
|
|||||||
45
docs/source/getting_started/debugging.rst
Normal file
45
docs/source/getting_started/debugging.rst
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
.. _debugging:
|
||||||
|
|
||||||
|
Debugging Tips
|
||||||
|
===============
|
||||||
|
|
||||||
|
Debugging hang/crash issues
|
||||||
|
---------------------------
|
||||||
|
|
||||||
|
When a vLLM instance hangs or crashes, it is very difficult to debug the issue. But wait a minute, it is also possible that vLLM is doing something that indeed takes a long time:
|
||||||
|
|
||||||
|
- **Downloading a model**: Do you have the model already downloaded on your disk? If not, vLLM will download the model from the internet, which can take a long time. Be sure to check the internet connection. It would be better to download the model first using `huggingface-cli <https://huggingface.co/docs/huggingface_hub/en/guides/cli>`_ and then use the local path to the model. This way, you can isolate the issue.
|
||||||
|
- **Loading the model from disk**: If the model is large, it can take a long time to load the model from disk. Pay attention to where you store the model. Some clusters have shared filesystems across nodes, e.g. distributed filesystem or network filesystem, which can be slow. It would be better to store the model on a local disk. In addition, please also watch the CPU memory usage. When the model is too large, it might take much CPU memory, which can slow down the operating system because it needs to frequently swap memory between the disk and the memory.
|
||||||
|
- **Tensor parallel inference**: If the model is too large to fit in a single GPU, you might want to use tensor parallelism to split the model across multiple GPUs. In that case, every process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). You can convert the model checkpoint to a sharded checkpoint using `the provided script <https://docs.vllm.ai/en/latest/getting_started/examples/save_sharded_state.html>`_ . The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism.
|
||||||
|
|
||||||
|
If you have already taken care of the above issues, but the vLLM instance still hangs, with CPU and GPU utilization at near zero, it is likely that the vLLM instance is stuck somewhere. Here are some tips to help debug the issue:
|
||||||
|
|
||||||
|
- Set the environment variable ``export VLLM_LOGGING_LEVEL=DEBUG`` to turn on more logging.
|
||||||
|
- Set the environment variable ``export CUDA_LAUNCH_BLOCKING=1`` to know exactly which CUDA kernel is causing the trouble.
|
||||||
|
- Set the environment variable ``export NCCL_DEBUG=TRACE`` to turn on more logging for NCCL.
|
||||||
|
- Set the environment variable ``export VLLM_TRACE_FUNCTION=1``. All the function calls in vLLM will be recorded. Inspect these log files, and tell which function crashes or hangs.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
vLLM function tracing will generate a lot of logs and slow down the system. Only use it for debugging purposes.
|
||||||
|
|
||||||
|
With more logging, hopefully you can find the root cause of the issue.
|
||||||
|
|
||||||
|
Here are some common issues that can cause hangs:
|
||||||
|
|
||||||
|
- **Incorrect network setup**: The vLLM instance cannot get the correct IP address. You can find the log such as ``DEBUG 06-10 21:32:17 parallel_state.py:88] world_size=8 rank=0 local_rank=0 distributed_init_method=tcp://xxx.xxx.xxx.xxx:54641 backend=nccl``. The IP address should be the correct one. If not, override the IP address by setting the environment variable ``export VLLM_HOST_IP=your_ip_address``.
|
||||||
|
- **Incorrect hardware/driver**: GPU communication cannot be established. You can run the following sanity check script to see if the GPU communication is working correctly.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
# save it as `test.py` , and run it with `NCCL_DEBUG=TRACE torchrun --nproc-per-node=8 test.py`
|
||||||
|
# adjust `--nproc-per-node` to the number of GPUs you want to use.
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
dist.init_process_group(backend="nccl")
|
||||||
|
data = torch.FloatTensor([1,] * 128).to(f"cuda:{dist.get_rank()}")
|
||||||
|
dist.all_reduce(data, op=dist.ReduceOp.SUM)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
value = data.mean().item()
|
||||||
|
assert value == dist.get_world_size()
|
||||||
|
|
||||||
|
If the problem persists, feel free to `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_, with a detailed description of the issue, your environment, and the logs.
|
||||||
75
docs/source/getting_started/tpu-installation.rst
Normal file
75
docs/source/getting_started/tpu-installation.rst
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
.. _installation_tpu:
|
||||||
|
|
||||||
|
Installation with TPU
|
||||||
|
=====================
|
||||||
|
|
||||||
|
vLLM supports Google Cloud TPUs using PyTorch XLA.
|
||||||
|
|
||||||
|
Requirements
|
||||||
|
------------
|
||||||
|
|
||||||
|
* Google Cloud TPU VM (single host)
|
||||||
|
* TPU versions: v5e, v5p, v4
|
||||||
|
* Python: 3.10
|
||||||
|
|
||||||
|
Installation options:
|
||||||
|
|
||||||
|
1. :ref:`Build a docker image with Dockerfile <build_docker_tpu>`.
|
||||||
|
2. :ref:`Build from source <build_from_source_tpu>`.
|
||||||
|
|
||||||
|
.. _build_docker_tpu:
|
||||||
|
|
||||||
|
Build a docker image with :code:`Dockerfile.tpu`
|
||||||
|
------------------------------------------------
|
||||||
|
|
||||||
|
`Dockerfile.tpu <https://github.com/vllm-project/vllm/blob/main/Dockerfile.tpu>`_ is provided to build a docker image with TPU support.
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ docker build -f Dockerfile.tpu -t vllm-tpu .
|
||||||
|
|
||||||
|
|
||||||
|
You can run the docker image with the following command:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ # Make sure to add `--privileged --net host --shm-size=16G`.
|
||||||
|
$ docker run --privileged --net host --shm-size=16G -it vllm-tpu
|
||||||
|
|
||||||
|
|
||||||
|
.. _build_from_source_tpu:
|
||||||
|
|
||||||
|
Build from source
|
||||||
|
-----------------
|
||||||
|
|
||||||
|
You can also build and install the TPU backend from source.
|
||||||
|
|
||||||
|
First, install the dependencies:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ # (Recommended) Create a new conda environment.
|
||||||
|
$ conda create -n myenv python=3.10 -y
|
||||||
|
$ conda activate myenv
|
||||||
|
|
||||||
|
$ # Clean up the existing torch and torch-xla packages.
|
||||||
|
$ pip uninstall torch torch-xla -y
|
||||||
|
|
||||||
|
$ # Install PyTorch and PyTorch XLA.
|
||||||
|
$ export DATE="+20240601"
|
||||||
|
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
|
||||||
|
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
|
||||||
|
|
||||||
|
$ # Install JAX and Pallas.
|
||||||
|
$ pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
|
||||||
|
$ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||||
|
|
||||||
|
$ # Install other build dependencies.
|
||||||
|
$ pip install packaging aiohttp
|
||||||
|
|
||||||
|
|
||||||
|
Next, build vLLM from source. This will only take a few seconds:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ VLLM_TARGET_DEVICE="tpu" python setup.py develop
|
||||||
@@ -63,9 +63,11 @@ Documentation
|
|||||||
|
|
||||||
getting_started/installation
|
getting_started/installation
|
||||||
getting_started/amd-installation
|
getting_started/amd-installation
|
||||||
getting_started/neuron-installation
|
|
||||||
getting_started/cpu-installation
|
getting_started/cpu-installation
|
||||||
|
getting_started/neuron-installation
|
||||||
|
getting_started/tpu-installation
|
||||||
getting_started/quickstart
|
getting_started/quickstart
|
||||||
|
getting_started/debugging
|
||||||
getting_started/examples/examples_index
|
getting_started/examples/examples_index
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
@@ -88,6 +90,8 @@ Documentation
|
|||||||
models/adding_model
|
models/adding_model
|
||||||
models/engine_args
|
models/engine_args
|
||||||
models/lora
|
models/lora
|
||||||
|
models/vlm
|
||||||
|
models/spec_decode
|
||||||
models/performance
|
models/performance
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
@@ -95,21 +99,29 @@ Documentation
|
|||||||
:caption: Quantization
|
:caption: Quantization
|
||||||
|
|
||||||
quantization/auto_awq
|
quantization/auto_awq
|
||||||
|
quantization/fp8
|
||||||
quantization/fp8_e5m2_kvcache
|
quantization/fp8_e5m2_kvcache
|
||||||
quantization/fp8_e4m3_kvcache
|
quantization/fp8_e4m3_kvcache
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 2
|
:maxdepth: 1
|
||||||
|
:caption: Automatic Prefix Caching
|
||||||
|
|
||||||
|
automatic_prefix_caching/apc
|
||||||
|
automatic_prefix_caching/details
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
:caption: Developer Documentation
|
:caption: Developer Documentation
|
||||||
|
|
||||||
dev/sampling_params
|
dev/sampling_params
|
||||||
dev/offline_inference/offline_index
|
dev/offline_inference/offline_index
|
||||||
dev/engine/engine_index
|
dev/engine/engine_index
|
||||||
dev/kernel/paged_attention
|
dev/kernel/paged_attention
|
||||||
|
dev/multimodal/multimodal_index
|
||||||
dev/dockerfile/dockerfile
|
dev/dockerfile/dockerfile
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 2
|
:maxdepth: 1
|
||||||
:caption: Community
|
:caption: Community
|
||||||
|
|
||||||
community/meetups
|
community/meetups
|
||||||
|
|||||||
77
docs/source/models/spec_decode.rst
Normal file
77
docs/source/models/spec_decode.rst
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
.. _spec_decode:
|
||||||
|
|
||||||
|
Speculative decoding in vLLM
|
||||||
|
============================
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
Please note that speculative decoding in vLLM is not yet optimized and does
|
||||||
|
not usually yield inter-token latency reductions for all prompt datasets or sampling parameters. The work
|
||||||
|
to optimize it is ongoing and can be followed in `this issue <https://github.com/vllm-project/vllm/issues/4630>`_.
|
||||||
|
|
||||||
|
This document shows how to use `Speculative Decoding <https://x.com/karpathy/status/1697318534555336961>`_ with vLLM.
|
||||||
|
Speculative decoding is a technique which improves inter-token latency in memory-bound LLM inference.
|
||||||
|
|
||||||
|
Speculating with a draft model
|
||||||
|
------------------------------
|
||||||
|
|
||||||
|
The following code configures vLLM to use speculative decoding with a draft model, speculating 5 tokens at a time.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||||
|
|
||||||
|
llm = LLM(
|
||||||
|
model="facebook/opt-6.7b",
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
speculative_model="facebook/opt-125m",
|
||||||
|
num_speculative_tokens=5,
|
||||||
|
use_v2_block_manager=True,
|
||||||
|
)
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
|
||||||
|
Speculating by matching n-grams in the prompt
|
||||||
|
---------------------------------------------
|
||||||
|
|
||||||
|
The following code configures vLLM to use speculative decoding where proposals are generated by
|
||||||
|
matching n-grams in the prompt. For more information, read `this thread <https://x.com/joao_gante/status/1747322413006643259>`_.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||||
|
|
||||||
|
llm = LLM(
|
||||||
|
model="facebook/opt-6.7b",
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
speculative_model="[ngram]",
|
||||||
|
num_speculative_tokens=5,
|
||||||
|
ngram_prompt_lookup_max=4,
|
||||||
|
use_v2_block_manager=True,
|
||||||
|
)
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
|
||||||
|
Resources for vLLM contributors
|
||||||
|
-------------------------------
|
||||||
|
* `A Hacker's Guide to Speculative Decoding in vLLM <https://www.youtube.com/watch?v=9wNAgpX6z_4>`_
|
||||||
|
* `What is Lookahead Scheduling in vLLM? <https://docs.google.com/document/d/1Z9TvqzzBPnh5WHcRwjvK2UEeFeq5zMZb5mFE8jR0HCs/edit#heading=h.1fjfb0donq5a>`_
|
||||||
|
* `Information on batch expansion. <https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit#heading=h.kk7dq05lc6q8>`_
|
||||||
|
* `Dynamic speculative decoding <https://github.com/vllm-project/vllm/issues/4565>`_
|
||||||
@@ -87,6 +87,14 @@ Alongside each architecture, we include some popular models that use it.
|
|||||||
- LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi
|
- LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi
|
||||||
- :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
|
- :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
|
||||||
- ✅︎
|
- ✅︎
|
||||||
|
* - :code:`LlavaForConditionalGeneration`
|
||||||
|
- LLaVA-1.5
|
||||||
|
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
|
||||||
|
-
|
||||||
|
* - :code:`LlavaNextForConditionalGeneration`
|
||||||
|
- LLaVA-NeXT
|
||||||
|
- :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
|
||||||
|
-
|
||||||
* - :code:`MiniCPMForCausalLM`
|
* - :code:`MiniCPMForCausalLM`
|
||||||
- MiniCPM
|
- MiniCPM
|
||||||
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
|
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
|
||||||
|
|||||||
130
docs/source/models/vlm.rst
Normal file
130
docs/source/models/vlm.rst
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
.. _vlm:
|
||||||
|
|
||||||
|
Using VLMs
|
||||||
|
==========
|
||||||
|
|
||||||
|
vLLM provides experimental support for Vision Language Models (VLMs). This document shows you how to run and serve these models using vLLM.
|
||||||
|
|
||||||
|
Engine Arguments
|
||||||
|
----------------
|
||||||
|
|
||||||
|
The following :ref:`engine arguments <engine_args>` are specific to VLMs:
|
||||||
|
|
||||||
|
.. argparse::
|
||||||
|
:module: vllm.engine.arg_utils
|
||||||
|
:func: _vlm_engine_args_parser
|
||||||
|
:prog: -m vllm.entrypoints.openai.api_server
|
||||||
|
:nodefaultconst:
|
||||||
|
|
||||||
|
.. important::
|
||||||
|
Currently, the support for vision language models on vLLM has the following limitations:
|
||||||
|
|
||||||
|
* Only single image input is supported per text prompt.
|
||||||
|
* Dynamic ``image_input_shape`` is not supported: the input image will be resized to the static ``image_input_shape``. This means our LLaVA-NeXT output may not exactly match the Hugging Face implementation.
|
||||||
|
|
||||||
|
We are continuously improving user & developer experience for VLMs. Please `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_ if you have any feedback or feature requests.
|
||||||
|
|
||||||
|
Offline Batched Inference
|
||||||
|
-------------------------
|
||||||
|
|
||||||
|
To initialize a VLM, the aforementioned arguments must be passed to the ``LLM`` class for instantiating the engine.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
llm = LLM(
|
||||||
|
model="llava-hf/llava-1.5-7b-hf",
|
||||||
|
image_input_type="pixel_values",
|
||||||
|
image_token_id=32000,
|
||||||
|
image_input_shape="1,3,336,336",
|
||||||
|
image_feature_size=576,
|
||||||
|
)
|
||||||
|
|
||||||
|
To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
|
||||||
|
|
||||||
|
* ``prompt``: The prompt should have a number of ``<image>`` tokens equal to ``image_feature_size``.
|
||||||
|
* ``multi_modal_data``: This should be an instance of :class:`~vllm.multimodal.image.ImagePixelData` or :class:`~vllm.multimodal.image.ImageFeatureData`.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
prompt = "<image>" * 576 + (
|
||||||
|
"\nUSER: What is the content of this image?\nASSISTANT:")
|
||||||
|
|
||||||
|
# Load the image using PIL.Image
|
||||||
|
image = ...
|
||||||
|
|
||||||
|
outputs = llm.generate({
|
||||||
|
"prompt": prompt,
|
||||||
|
"multi_modal_data": ImagePixelData(image),
|
||||||
|
})
|
||||||
|
|
||||||
|
for o in outputs:
|
||||||
|
generated_text = o.outputs[0].text
|
||||||
|
print(generated_text)
|
||||||
|
|
||||||
|
A code example can be found in `examples/llava_example.py <https://github.com/vllm-project/vllm/blob/main/examples/llava_example.py>`_.
|
||||||
|
|
||||||
|
Online OpenAI Vision API Compatible Inference
|
||||||
|
----------------------------------------------
|
||||||
|
|
||||||
|
You can serve vision language models with vLLM's HTTP server that is compatible with `OpenAI Vision API <https://platform.openai.com/docs/guides/vision>`_.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
Currently, vLLM supports only **single** ``image_url`` input per ``messages``. Support for multi-image inputs will be
|
||||||
|
added in the future.
|
||||||
|
|
||||||
|
Below is an example of how to launch the same ``llava-hf/llava-1.5-7b-hf`` model with the vLLM API server.
|
||||||
|
|
||||||
|
.. important::
|
||||||
|
Since OpenAI Vision API is based on `Chat <https://platform.openai.com/docs/api-reference/chat>`_ API, a chat template
|
||||||
|
is **required** to launch the API server if the model's tokenizer does not come with one. In this example, we use the
|
||||||
|
HuggingFace Llava chat template that you can find in the example folder `here <https://github.com/vllm-project/vllm/blob/main/examples/template_llava.jinja>`_.
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
python -m vllm.entrypoints.openai.api_server \
|
||||||
|
--model llava-hf/llava-1.5-7b-hf \
|
||||||
|
--image-input-type pixel_values \
|
||||||
|
--image-token-id 32000 \
|
||||||
|
--image-input-shape 1,3,336,336 \
|
||||||
|
--image-feature-size 576 \
|
||||||
|
--chat-template template_llava.jinja
|
||||||
|
|
||||||
|
To consume the server, you can use the OpenAI client like in the example below:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
openai_api_key = "EMPTY"
|
||||||
|
openai_api_base = "http://localhost:8000/v1"
|
||||||
|
client = OpenAI(
|
||||||
|
api_key=openai_api_key,
|
||||||
|
base_url=openai_api_base,
|
||||||
|
)
|
||||||
|
chat_response = client.chat.completions.create(
|
||||||
|
model="llava-hf/llava-1.5-7b-hf",
|
||||||
|
messages=[{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What's in this image?"},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}],
|
||||||
|
)
|
||||||
|
print("Chat response:", chat_response)
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
By default, the timeout for fetching images through an HTTP URL is ``5`` seconds. You can override this by setting the environment variable:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
export VLLM_IMAGE_FETCH_TIMEOUT=<timeout>
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The prompt formatting with the image token ``<image>`` is not needed when serving VLMs with the API server since the prompt will be
|
||||||
|
processed automatically by the server.
|
||||||
206
docs/source/quantization/fp8.rst
Normal file
206
docs/source/quantization/fp8.rst
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
.. _fp8:
|
||||||
|
|
||||||
|
FP8
|
||||||
|
==================
|
||||||
|
|
||||||
|
vLLM supports FP8 (8-bit floating point) computation using hardware acceleration on GPUs such as Nvidia H100 and AMD MI300x. Currently, only Hopper and Ada Lovelace GPUs are supported. Quantization of models with FP8 allows for a 2x reduction in model memory requirements and up to a 1.6x improvement in throughput with minimal impact on accuracy.
|
||||||
|
|
||||||
|
Please visit the HF collection of `quantized FP8 checkpoints of popular LLMs ready to use with vLLM <https://huggingface.co/collections/neuralmagic/fp8-llms-for-vllm-666742ed2b78b7ac8df13127>`_.
|
||||||
|
|
||||||
|
The FP8 types typically supported in hardware have two distinct representations, each useful in different scenarios:
|
||||||
|
|
||||||
|
- **E4M3**: Consists of 1 sign bit, 4 exponent bits, and 3 bits of mantissa. It can store values up to +/-448 and ``nan``.
|
||||||
|
- **E5M2**: Consists of 1 sign bit, 5 exponent bits, and 2 bits of mantissa. It can store values up to +/-57344, +/- ``inf``, and ``nan``. The tradeoff for the increased dynamic range is lower precision of the stored values.
|
||||||
|
|
||||||
|
Quick Start with Online Dynamic Quantization
|
||||||
|
--------------------------------------------
|
||||||
|
|
||||||
|
Dynamic quantization of an original precision BF16/FP16 model to FP8 can be achieved with vLLM without any calibration data required. You can enable the feature by specifying ``--quantization="fp8"`` in the command line or setting ``quantization="fp8"`` in the LLM constructor.
|
||||||
|
|
||||||
|
In this mode, all Linear modules (except for the final ``lm_head``) have their weights quantized down to FP8_E4M3 precision with a per-tensor scale. Activations have their minimum and maximum values calculated during each forward pass to provide a dynamic per-tensor scale for high accuracy. As a result, latency improvements are limited in this mode.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from vllm import LLM
|
||||||
|
model = LLM("facebook/opt-125m", quantization="fp8")
|
||||||
|
# INFO 06-10 17:55:42 model_runner.py:157] Loading model weights took 0.1550 GB
|
||||||
|
result = model.generate("Hello, my name is")
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
Currently, we load the model at original precision before quantizing down to 8-bits, so you need enough memory to load the whole model.
|
||||||
|
|
||||||
|
Offline Quantization
|
||||||
|
--------------------
|
||||||
|
|
||||||
|
For offline quantization to FP8, please install the `AutoFP8 library <https://github.com/neuralmagic/autofp8>`_.
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
git clone https://github.com/neuralmagic/AutoFP8.git
|
||||||
|
pip install -e AutoFP8
|
||||||
|
|
||||||
|
This package introduces the ``AutoFP8ForCausalLM`` and ``BaseQuantizeConfig`` objects for managing how your model will be compressed.
|
||||||
|
|
||||||
|
Offline Quantization with Dynamic Activation Scaling Factors
|
||||||
|
------------------------------------------------------------
|
||||||
|
|
||||||
|
You can use AutoFP8 to produce checkpoints with their weights quantized to FP8 ahead of time and let vLLM handle calculating dynamic scales for the activations at runtime for maximum accuracy. You can enable this with the ``activation_scheme="dynamic"`` argument.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
Please note that although this mode doesn't give you better performance, it reduces memory footprint compared to online quantization.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
|
||||||
|
|
||||||
|
pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||||
|
quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8-Dynamic"
|
||||||
|
|
||||||
|
# Define quantization config with dynamic activation scales
|
||||||
|
quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="dynamic")
|
||||||
|
# For dynamic activation scales, there is no need for calibration examples
|
||||||
|
examples = []
|
||||||
|
|
||||||
|
# Load the model, quantize, and save checkpoint
|
||||||
|
model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
|
||||||
|
model.quantize(examples)
|
||||||
|
model.save_quantized(quantized_model_dir)
|
||||||
|
|
||||||
|
In the output of the above script, you should be able to see the quantized Linear modules (FP8DynamicLinear) replaced in the model definition.
|
||||||
|
Note that the ``lm_head`` Linear module at the end is currently skipped by default.
|
||||||
|
|
||||||
|
.. code-block:: text
|
||||||
|
|
||||||
|
LlamaForCausalLM(
|
||||||
|
(model): LlamaModel(
|
||||||
|
(embed_tokens): Embedding(128256, 4096)
|
||||||
|
(layers): ModuleList(
|
||||||
|
(0-31): 32 x LlamaDecoderLayer(
|
||||||
|
(self_attn): LlamaSdpaAttention(
|
||||||
|
(q_proj): FP8DynamicLinear()
|
||||||
|
(k_proj): FP8DynamicLinear()
|
||||||
|
(v_proj): FP8DynamicLinear()
|
||||||
|
(o_proj): FP8DynamicLinear()
|
||||||
|
(rotary_emb): LlamaRotaryEmbedding()
|
||||||
|
)
|
||||||
|
(mlp): LlamaMLP(
|
||||||
|
(gate_proj): FP8DynamicLinear()
|
||||||
|
(up_proj): FP8DynamicLinear()
|
||||||
|
(down_proj): FP8DynamicLinear()
|
||||||
|
(act_fn): SiLU()
|
||||||
|
)
|
||||||
|
(input_layernorm): LlamaRMSNorm()
|
||||||
|
(post_attention_layernorm): LlamaRMSNorm()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
(norm): LlamaRMSNorm()
|
||||||
|
)
|
||||||
|
(lm_head): Linear(in_features=4096, out_features=128256, bias=False)
|
||||||
|
)
|
||||||
|
Saving the model to Meta-Llama-3-8B-Instruct-FP8-Dynamic
|
||||||
|
|
||||||
|
Your model checkpoint with quantized weights should be available at ``Meta-Llama-3-8B-Instruct-FP8-Dynamic/``.
|
||||||
|
We can see that the weights are smaller than the original BF16 precision.
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
ls -lh Meta-Llama-3-8B-Instruct-FP8-Dynamic/
|
||||||
|
total 8.5G
|
||||||
|
-rw-rw-r-- 1 user user 869 Jun 7 14:43 config.json
|
||||||
|
-rw-rw-r-- 1 user user 194 Jun 7 14:43 generation_config.json
|
||||||
|
-rw-rw-r-- 1 user user 4.7G Jun 7 14:43 model-00001-of-00002.safetensors
|
||||||
|
-rw-rw-r-- 1 user user 3.9G Jun 7 14:43 model-00002-of-00002.safetensors
|
||||||
|
-rw-rw-r-- 1 user user 43K Jun 7 14:43 model.safetensors.index.json
|
||||||
|
-rw-rw-r-- 1 user user 296 Jun 7 14:43 special_tokens_map.json
|
||||||
|
-rw-rw-r-- 1 user user 50K Jun 7 14:43 tokenizer_config.json
|
||||||
|
-rw-rw-r-- 1 user user 8.7M Jun 7 14:43 tokenizer.json
|
||||||
|
|
||||||
|
Finally, you can load the quantized model checkpoint directly in vLLM.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from vllm import LLM
|
||||||
|
model = LLM(model="Meta-Llama-3-8B-Instruct-FP8-Dynamic/")
|
||||||
|
# INFO 06-10 21:15:41 model_runner.py:159] Loading model weights took 8.4596 GB
|
||||||
|
result = model.generate("Hello, my name is")
|
||||||
|
|
||||||
|
Offline Quantization with Static Activation Scaling Factors
|
||||||
|
-----------------------------------------------------------
|
||||||
|
|
||||||
|
For the best inference performance, you can use AutoFP8 with calibration data to produce per-tensor static scales for both the weights and activations by enabling the ``activation_scheme="static"`` argument.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
|
||||||
|
|
||||||
|
pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||||
|
quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
# Load and tokenize 512 dataset samples for calibration of activation scales
|
||||||
|
ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512))
|
||||||
|
examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
|
||||||
|
examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")
|
||||||
|
|
||||||
|
# Define quantization config with static activation scales
|
||||||
|
quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")
|
||||||
|
|
||||||
|
# Load the model, quantize, and save checkpoint
|
||||||
|
model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
|
||||||
|
model.quantize(examples)
|
||||||
|
model.save_quantized(quantized_model_dir)
|
||||||
|
|
||||||
|
Your model checkpoint with quantized weights and activations should be available at ``Meta-Llama-3-8B-Instruct-FP8/``.
|
||||||
|
Finally, you can load the quantized model checkpoint directly in vLLM.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from vllm import LLM
|
||||||
|
model = LLM(model="Meta-Llama-3-8B-Instruct-FP8/")
|
||||||
|
# INFO 06-10 21:15:41 model_runner.py:159] Loading model weights took 8.4596 GB
|
||||||
|
result = model.generate("Hello, my name is")
|
||||||
|
|
||||||
|
FP8 checkpoint structure explanation
|
||||||
|
-----------------------------------------------------------
|
||||||
|
|
||||||
|
Here we detail the structure for the FP8 checkpoints.
|
||||||
|
|
||||||
|
The following must be present in the model's ``config.json``:
|
||||||
|
|
||||||
|
.. code-block:: text
|
||||||
|
|
||||||
|
"quantization_config": {
|
||||||
|
"quant_method": "fp8",
|
||||||
|
"activation_scheme": "static" or "dynamic"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Each quantized layer in the state_dict will have these tensors:
|
||||||
|
|
||||||
|
* If the config has ``"activation_scheme": "static"``:
|
||||||
|
|
||||||
|
.. code-block:: text
|
||||||
|
|
||||||
|
model.layers.0.mlp.down_proj.weight < F8_E4M3
|
||||||
|
model.layers.0.mlp.down_proj.input_scale < F32
|
||||||
|
model.layers.0.mlp.down_proj.weight_scale < F32
|
||||||
|
|
||||||
|
* If the config has ``"activation_scheme": "dynamic"``:
|
||||||
|
|
||||||
|
.. code-block:: text
|
||||||
|
|
||||||
|
model.layers.0.mlp.down_proj.weight < F8_E4M3
|
||||||
|
model.layers.0.mlp.down_proj.weight_scale < F32
|
||||||
|
|
||||||
|
|
||||||
|
Additionally, there can be `FP8 kv-cache scaling factors <https://github.com/vllm-project/vllm/pull/4893>`_ contained within quantized checkpoints specified through the ``.kv_scale`` parameter present on the Attention Module, such as:
|
||||||
|
|
||||||
|
.. code-block:: text
|
||||||
|
|
||||||
|
model.layers.0.self_attn.kv_scale < F32
|
||||||
@@ -3,11 +3,9 @@
|
|||||||
Distributed Inference and Serving
|
Distributed Inference and Serving
|
||||||
=================================
|
=================================
|
||||||
|
|
||||||
vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We manage the distributed runtime with `Ray <https://github.com/ray-project/ray>`_. To run distributed inference, install Ray with:
|
vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We manage the distributed runtime with either `Ray <https://github.com/ray-project/ray>`_ or Python's native multiprocessing. Multiprocessing can be used when deploying on a single node; multi-node inference currently requires Ray.
|
||||||
|
|
||||||
.. code-block:: console
|
Multiprocessing will be used by default when not running in a Ray placement group and if there are sufficient GPUs available on the same node for the configured :code:`tensor_parallel_size`; otherwise, Ray will be used. This default can be overridden via the :code:`LLM` class :code:`distributed_executor_backend` argument or the :code:`--distributed-executor-backend` API server argument. Set it to :code:`mp` for multiprocessing or :code:`ray` for Ray. Ray does not need to be installed for the multiprocessing case.
|
||||||
|
|
||||||
$ pip install ray
|
|
||||||
|
|
||||||
To run multi-GPU inference with the :code:`LLM` class, set the :code:`tensor_parallel_size` argument to the number of GPUs you want to use. For example, to run inference on 4 GPUs:
|
To run multi-GPU inference with the :code:`LLM` class, set the :code:`tensor_parallel_size` argument to the number of GPUs you want to use. For example, to run inference on 4 GPUs:
|
||||||
|
|
||||||
@@ -25,10 +23,12 @@ To run multi-GPU serving, pass in the :code:`--tensor-parallel-size` argument wh
|
|||||||
$ --model facebook/opt-13b \
|
$ --model facebook/opt-13b \
|
||||||
$ --tensor-parallel-size 4
|
$ --tensor-parallel-size 4
|
||||||
|
|
||||||
To scale vLLM beyond a single machine, start a `Ray runtime <https://docs.ray.io/en/latest/ray-core/starting-ray.html>`_ via CLI before running vLLM:
|
To scale vLLM beyond a single machine, install and start a `Ray runtime <https://docs.ray.io/en/latest/ray-core/starting-ray.html>`_ via CLI before running vLLM:
|
||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ pip install ray
|
||||||
|
|
||||||
$ # On head node
|
$ # On head node
|
||||||
$ ray start --head
|
$ ray start --head
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ Please see the [OpenAI API Reference](https://platform.openai.com/docs/api-refer
|
|||||||
- Chat: `tools`, and `tool_choice`.
|
- Chat: `tools`, and `tool_choice`.
|
||||||
- Completions: `suffix`.
|
- Completions: `suffix`.
|
||||||
|
|
||||||
|
vLLM also provides experimental support for OpenAI Vision API compatible inference. See more details in [Using VLMs](../models/vlm.rst).
|
||||||
|
|
||||||
## Extra Parameters
|
## Extra Parameters
|
||||||
vLLM supports a set of parameters that are not part of the OpenAI API.
|
vLLM supports a set of parameters that are not part of the OpenAI API.
|
||||||
In order to use them, you can pass them as extra parameters in the OpenAI client.
|
In order to use them, you can pass them as extra parameters in the OpenAI client.
|
||||||
@@ -109,4 +111,15 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
|
|||||||
:module: vllm.entrypoints.openai.cli_args
|
:module: vllm.entrypoints.openai.cli_args
|
||||||
:func: make_arg_parser
|
:func: make_arg_parser
|
||||||
:prog: -m vllm.entrypoints.openai.api_server
|
:prog: -m vllm.entrypoints.openai.api_server
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Tool calling in the chat completion API
|
||||||
|
vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but are on the roadmap.
|
||||||
|
|
||||||
|
To use a named function you need to define the function in the `tools` parameter and call it in the `tool_choice` parameter.
|
||||||
|
|
||||||
|
It is the caller's responsibility to prompt the model with the tool information; vLLM will not automatically manipulate the prompt. **This may change in the future.**
|
||||||
|
|
||||||
|
vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.
|
||||||
|
|
||||||
|
Please refer to the OpenAI API reference documentation for more information.
|
||||||
|
|||||||
@@ -3,33 +3,36 @@ import os
|
|||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
from vllm import LLM
|
from vllm import LLM
|
||||||
from vllm.sequence import MultiModalData
|
from vllm.multimodal.image import ImageFeatureData, ImagePixelData
|
||||||
|
|
||||||
# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
|
# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
|
||||||
|
# You can use `.buildkite/download-images.sh` to download them
|
||||||
|
|
||||||
|
|
||||||
def run_llava_pixel_values():
|
def run_llava_pixel_values(*, disable_image_processor: bool = False):
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
model="llava-hf/llava-1.5-7b-hf",
|
model="llava-hf/llava-1.5-7b-hf",
|
||||||
image_input_type="pixel_values",
|
image_input_type="pixel_values",
|
||||||
image_token_id=32000,
|
image_token_id=32000,
|
||||||
image_input_shape="1,3,336,336",
|
image_input_shape="1,3,336,336",
|
||||||
image_feature_size=576,
|
image_feature_size=576,
|
||||||
|
disable_image_processor=disable_image_processor,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = "<image>" * 576 + (
|
prompt = "<image>" * 576 + (
|
||||||
"\nUSER: What is the content of this image?\nASSISTANT:")
|
"\nUSER: What is the content of this image?\nASSISTANT:")
|
||||||
|
|
||||||
# This should be provided by another online or offline component.
|
if disable_image_processor:
|
||||||
image = torch.load("images/stop_sign_pixel_values.pt")
|
image = torch.load("images/stop_sign_pixel_values.pt")
|
||||||
|
else:
|
||||||
|
image = Image.open("images/stop_sign.jpg")
|
||||||
|
|
||||||
outputs = llm.generate({
|
outputs = llm.generate({
|
||||||
"prompt":
|
"prompt": prompt,
|
||||||
prompt,
|
"multi_modal_data": ImagePixelData(image),
|
||||||
"multi_modal_data":
|
|
||||||
MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
|
|
||||||
})
|
})
|
||||||
|
|
||||||
for o in outputs:
|
for o in outputs:
|
||||||
@@ -49,15 +52,13 @@ def run_llava_image_features():
|
|||||||
prompt = "<image>" * 576 + (
|
prompt = "<image>" * 576 + (
|
||||||
"\nUSER: What is the content of this image?\nASSISTANT:")
|
"\nUSER: What is the content of this image?\nASSISTANT:")
|
||||||
|
|
||||||
# This should be provided by another online or offline component.
|
image: torch.Tensor = torch.load("images/stop_sign_image_features.pt")
|
||||||
image = torch.load("images/stop_sign_image_features.pt")
|
|
||||||
|
|
||||||
outputs = llm.generate({
|
outputs = llm.generate({
|
||||||
"prompt":
|
"prompt": prompt,
|
||||||
prompt,
|
"multi_modal_data": ImageFeatureData(image),
|
||||||
"multi_modal_data":
|
|
||||||
MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
|
|
||||||
})
|
})
|
||||||
|
|
||||||
for o in outputs:
|
for o in outputs:
|
||||||
generated_text = o.outputs[0].text
|
generated_text = o.outputs[0].text
|
||||||
print(generated_text)
|
print(generated_text)
|
||||||
|
|||||||
140
examples/lora_with_quantization_inference.py
Normal file
140
examples/lora_with_quantization_inference.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
"""
|
||||||
|
This example shows how to use LoRA with different quantization techniques
|
||||||
|
for offline inference.
|
||||||
|
|
||||||
|
Requires HuggingFace credentials for access.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_prompts(
|
||||||
|
lora_path: str
|
||||||
|
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
|
||||||
|
return [
|
||||||
|
# this is an example of using quantization without LoRA
|
||||||
|
("My name is",
|
||||||
|
SamplingParams(temperature=0.0,
|
||||||
|
logprobs=1,
|
||||||
|
prompt_logprobs=1,
|
||||||
|
max_tokens=128), None),
|
||||||
|
# the next three examples use quantization with LoRA
|
||||||
|
("my name is",
|
||||||
|
SamplingParams(temperature=0.0,
|
||||||
|
logprobs=1,
|
||||||
|
prompt_logprobs=1,
|
||||||
|
max_tokens=128),
|
||||||
|
LoRARequest("lora-test-1", 1, lora_path)),
|
||||||
|
("The capital of USA is",
|
||||||
|
SamplingParams(temperature=0.0,
|
||||||
|
logprobs=1,
|
||||||
|
prompt_logprobs=1,
|
||||||
|
max_tokens=128),
|
||||||
|
LoRARequest("lora-test-2", 1, lora_path)),
|
||||||
|
("The capital of France is",
|
||||||
|
SamplingParams(temperature=0.0,
|
||||||
|
logprobs=1,
|
||||||
|
prompt_logprobs=1,
|
||||||
|
max_tokens=128),
|
||||||
|
LoRARequest("lora-test-3", 1, lora_path)),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def process_requests(engine: LLMEngine,
|
||||||
|
test_prompts: List[Tuple[str, SamplingParams,
|
||||||
|
Optional[LoRARequest]]]):
|
||||||
|
"""Continuously process a list of prompts and handle the outputs."""
|
||||||
|
request_id = 0
|
||||||
|
|
||||||
|
while test_prompts or engine.has_unfinished_requests():
|
||||||
|
if test_prompts:
|
||||||
|
prompt, sampling_params, lora_request = test_prompts.pop(0)
|
||||||
|
engine.add_request(str(request_id),
|
||||||
|
prompt,
|
||||||
|
sampling_params,
|
||||||
|
lora_request=lora_request)
|
||||||
|
request_id += 1
|
||||||
|
|
||||||
|
request_outputs: List[RequestOutput] = engine.step()
|
||||||
|
for request_output in request_outputs:
|
||||||
|
if request_output.finished:
|
||||||
|
print("----------------------------------------------------")
|
||||||
|
print(f"Prompt: {request_output.prompt}")
|
||||||
|
print(f"Output: {request_output.outputs[0].text}")
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_engine(model: str, quantization: str,
|
||||||
|
lora_repo: Optional[str]) -> LLMEngine:
|
||||||
|
"""Initialize the LLMEngine."""
|
||||||
|
|
||||||
|
if quantization == "bitsandbytes":
|
||||||
|
# QLoRA (https://arxiv.org/abs/2305.14314) is a quantization technique.
|
||||||
|
# It quantizes the model when loading, with some config info from the
|
||||||
|
# LoRA adapter repo, so we need to set load_format and
|
||||||
|
# qlora_adapter_name_or_path as below.
|
||||||
|
engine_args = EngineArgs(
|
||||||
|
model=model,
|
||||||
|
quantization=quantization,
|
||||||
|
qlora_adapter_name_or_path=lora_repo,
|
||||||
|
load_format="bitsandbytes",
|
||||||
|
enable_lora=True,
|
||||||
|
max_lora_rank=64,
|
||||||
|
# set it only on GPUs with limited memory
|
||||||
|
enforce_eager=True)
|
||||||
|
else:
|
||||||
|
engine_args = EngineArgs(
|
||||||
|
model=model,
|
||||||
|
quantization=quantization,
|
||||||
|
enable_lora=True,
|
||||||
|
max_loras=4,
|
||||||
|
# set it only on GPUs with limited memory
|
||||||
|
enforce_eager=True)
|
||||||
|
return LLMEngine.from_engine_args(engine_args)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main function that sets up and runs the prompt processing."""
|
||||||
|
|
||||||
|
test_configs = [{
|
||||||
|
"name": "qlora_inference_example",
|
||||||
|
'model': "huggyllama/llama-7b",
|
||||||
|
'quantization': "bitsandbytes",
|
||||||
|
'lora_repo': 'timdettmers/qlora-flan-7b'
|
||||||
|
}, {
|
||||||
|
"name": "AWQ_inference_with_lora_example",
|
||||||
|
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ',
|
||||||
|
'quantization': "awq",
|
||||||
|
'lora_repo': 'jashing/tinyllama-colorist-lora'
|
||||||
|
}, {
|
||||||
|
"name": "GPTQ_inference_with_lora_example",
|
||||||
|
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ',
|
||||||
|
'quantization': "gptq",
|
||||||
|
'lora_repo': 'jashing/tinyllama-colorist-lora'
|
||||||
|
}]
|
||||||
|
|
||||||
|
for test_config in test_configs:
|
||||||
|
print(
|
||||||
|
f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~"
|
||||||
|
)
|
||||||
|
engine = initialize_engine(test_config['model'],
|
||||||
|
test_config['quantization'],
|
||||||
|
test_config['lora_repo'])
|
||||||
|
lora_path = snapshot_download(repo_id=test_config['lora_repo'])
|
||||||
|
test_prompts = create_test_prompts(lora_path)
|
||||||
|
process_requests(engine, test_prompts)
|
||||||
|
|
||||||
|
# Clean up the GPU memory for the next test
|
||||||
|
del engine
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@@ -1,5 +1,8 @@
|
|||||||
|
from time import time
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
# Common prefix.
|
||||||
prefix = (
|
prefix = (
|
||||||
"You are an expert school principal, skilled in effectively managing "
|
"You are an expert school principal, skilled in effectively managing "
|
||||||
"faculty and staff. Draft 10-15 questions for a potential first grade "
|
"faculty and staff. Draft 10-15 questions for a potential first grade "
|
||||||
@@ -18,36 +21,62 @@ prompts = [
|
|||||||
"The capital of France is",
|
"The capital of France is",
|
||||||
"The future of AI is",
|
"The future of AI is",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
generating_prompts = [prefix + prompt for prompt in prompts]
|
||||||
|
|
||||||
# Create a sampling params object.
|
# Create a sampling params object.
|
||||||
sampling_params = SamplingParams(temperature=0.0)
|
sampling_params = SamplingParams(temperature=0.0)
|
||||||
|
|
||||||
# Create an LLM.
|
# Create an LLM.
|
||||||
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
|
regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4)
|
||||||
|
|
||||||
generating_prompts = [prefix + prompt for prompt in prompts]
|
prefix_cached_llm = LLM(model="facebook/opt-125m",
|
||||||
|
enable_prefix_caching=True,
|
||||||
|
gpu_memory_utilization=0.4)
|
||||||
|
print("Results without `enable_prefix_caching`")
|
||||||
|
|
||||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||||
# that contain the prompt, generated text, and other information.
|
# that contain the prompt, generated text, and other information.
|
||||||
outputs = llm.generate(generating_prompts, sampling_params)
|
start_time_regular = time()
|
||||||
|
outputs = regular_llm.generate(generating_prompts, sampling_params)
|
||||||
|
duration_regular = time() - start_time_regular
|
||||||
|
|
||||||
|
regular_generated_texts = []
|
||||||
# Print the outputs.
|
# Print the outputs.
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
prompt = output.prompt
|
prompt = output.prompt
|
||||||
generated_text = output.outputs[0].text
|
generated_text = output.outputs[0].text
|
||||||
|
regular_generated_texts.append(generated_text)
|
||||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
|
||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
|
|
||||||
# The llm.generate call will batch all prompts and send the batch at once
|
# Warmup so that the shared prompt's KV cache is computed.
|
||||||
# if resources allow. The prefix will only be cached after the first batch
|
prefix_cached_llm.generate(generating_prompts[0], sampling_params)
|
||||||
# is processed, so we need to call generate once to calculate the prefix
|
|
||||||
# and cache it.
|
|
||||||
outputs = llm.generate(generating_prompts[0], sampling_params)
|
|
||||||
|
|
||||||
# Subsequent batches can leverage the cached prefix
|
# Generate with prefix caching.
|
||||||
outputs = llm.generate(generating_prompts, sampling_params)
|
start_time_cached = time()
|
||||||
|
outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)
|
||||||
|
duration_cached = time() - start_time_cached
|
||||||
|
|
||||||
# Print the outputs. You should see the same outputs as before
|
print("Results with `enable_prefix_caching`")
|
||||||
|
|
||||||
|
cached_generated_texts = []
|
||||||
|
# Print the outputs. You should see the same outputs as before.
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
prompt = output.prompt
|
prompt = output.prompt
|
||||||
generated_text = output.outputs[0].text
|
generated_text = output.outputs[0].text
|
||||||
|
cached_generated_texts.append(generated_text)
|
||||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
|
||||||
|
print("-" * 80)
|
||||||
|
|
||||||
|
# Compare the results and display the speedup
|
||||||
|
generated_same = all([
|
||||||
|
regular_generated_texts[i] == cached_generated_texts[i]
|
||||||
|
for i in range(len(prompts))
|
||||||
|
])
|
||||||
|
print(f"Generated answers are the same: {generated_same}")
|
||||||
|
|
||||||
|
speedup = round(duration_regular / duration_cached, 2)
|
||||||
|
print(f"Speed up of cached generation compared to the regular is: {speedup}")
|
||||||
|
|||||||
23
examples/template_llava.jinja
Normal file
23
examples/template_llava.jinja
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
{%- if messages[0]['role'] == 'system' -%}
|
||||||
|
{%- set system_message = messages[0]['content'] -%}
|
||||||
|
{%- set messages = messages[1:] -%}
|
||||||
|
{%- else -%}
|
||||||
|
{% set system_message = '' -%}
|
||||||
|
{%- endif -%}
|
||||||
|
|
||||||
|
{{ bos_token + system_message }}
|
||||||
|
{%- for message in messages -%}
|
||||||
|
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
|
||||||
|
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
|
||||||
|
{%- endif -%}
|
||||||
|
|
||||||
|
{%- if message['role'] == 'user' -%}
|
||||||
|
{{ 'USER: ' + message['content'] + '\n' }}
|
||||||
|
{%- elif message['role'] == 'assistant' -%}
|
||||||
|
{{ 'ASSISTANT: ' + message['content'] + eos_token + '\n' }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
|
||||||
|
{%- if add_generation_prompt -%}
|
||||||
|
{{ 'ASSISTANT:' }}
|
||||||
|
{% endif %}
|
||||||
@@ -3,18 +3,12 @@ import dataclasses
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from tensorizer import stream_io
|
|
||||||
|
|
||||||
from vllm import LLM
|
from vllm import LLM
|
||||||
from vllm.distributed import (init_distributed_environment,
|
|
||||||
initialize_model_parallel)
|
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.engine.llm_engine import LLMEngine
|
|
||||||
from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs,
|
from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs,
|
||||||
TensorizerConfig,
|
TensorizerConfig,
|
||||||
serialize_vllm_model)
|
tensorize_vllm_model)
|
||||||
|
|
||||||
# yapf conflicts with isort for this docstring
|
# yapf conflicts with isort for this docstring
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
@@ -61,6 +55,12 @@ Which downloads the model tensors from your S3 bucket and deserializes them.
|
|||||||
You can also provide a `--keyfile` argument to decrypt the model weights if
|
You can also provide a `--keyfile` argument to decrypt the model weights if
|
||||||
they were serialized with encryption.
|
they were serialized with encryption.
|
||||||
|
|
||||||
|
To support distributed tensor-parallel models, each model shard will be
|
||||||
|
serialized to a separate file. The tensorizer_uri is then specified as a string
|
||||||
|
template with a format specifier such as '%03d' that will be rendered with the
|
||||||
|
shard's rank. Sharded models serialized with this script will be named as
|
||||||
|
model-rank-%03d.tensors
|
||||||
|
|
||||||
For more information on the available arguments for serializing, run
|
For more information on the available arguments for serializing, run
|
||||||
`python -m examples.tensorize_vllm_model serialize --help`.
|
`python -m examples.tensorize_vllm_model serialize --help`.
|
||||||
|
|
||||||
@@ -168,77 +168,72 @@ def parse_args():
|
|||||||
def deserialize():
|
def deserialize():
|
||||||
llm = LLM(model=args.model,
|
llm = LLM(model=args.model,
|
||||||
load_format="tensorizer",
|
load_format="tensorizer",
|
||||||
|
tensor_parallel_size=args.tensor_parallel_size,
|
||||||
model_loader_extra_config=tensorizer_config
|
model_loader_extra_config=tensorizer_config
|
||||||
)
|
)
|
||||||
return llm
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
args = parse_args()
|
s3_access_key_id = (getattr(args, 's3_access_key_id', None)
|
||||||
|
or os.environ.get("S3_ACCESS_KEY_ID", None))
|
||||||
|
s3_secret_access_key = (getattr(args, 's3_secret_access_key', None)
|
||||||
|
or os.environ.get("S3_SECRET_ACCESS_KEY", None))
|
||||||
|
s3_endpoint = (getattr(args, 's3_endpoint', None)
|
||||||
|
or os.environ.get("S3_ENDPOINT_URL", None))
|
||||||
|
|
||||||
s3_access_key_id = (getattr(args, 's3_access_key_id', None)
|
credentials = {
|
||||||
or os.environ.get("S3_ACCESS_KEY_ID", None))
|
"s3_access_key_id": s3_access_key_id,
|
||||||
s3_secret_access_key = (getattr(args, 's3_secret_access_key', None)
|
"s3_secret_access_key": s3_secret_access_key,
|
||||||
or os.environ.get("S3_SECRET_ACCESS_KEY", None))
|
"s3_endpoint": s3_endpoint
|
||||||
s3_endpoint = (getattr(args, 's3_endpoint', None)
|
}
|
||||||
or os.environ.get("S3_ENDPOINT_URL", None))
|
|
||||||
|
|
||||||
credentials = {
|
model_ref = args.model
|
||||||
"s3_access_key_id": s3_access_key_id,
|
|
||||||
"s3_secret_access_key": s3_secret_access_key,
|
|
||||||
"s3_endpoint": s3_endpoint
|
|
||||||
}
|
|
||||||
|
|
||||||
_read_stream, _write_stream = (partial(
|
model_name = model_ref.split("/")[1]
|
||||||
stream_io.open_stream,
|
|
||||||
mode=mode,
|
|
||||||
s3_access_key_id=s3_access_key_id,
|
|
||||||
s3_secret_access_key=s3_secret_access_key,
|
|
||||||
s3_endpoint=s3_endpoint,
|
|
||||||
) for mode in ("rb", "wb+"))
|
|
||||||
|
|
||||||
model_ref = args.model
|
keyfile = args.keyfile if args.keyfile else None
|
||||||
|
|
||||||
model_name = model_ref.split("/")[1]
|
if args.model_loader_extra_config:
|
||||||
|
config = json.loads(args.model_loader_extra_config)
|
||||||
|
tensorizer_args = \
|
||||||
|
TensorizerConfig(**config)._construct_tensorizer_args()
|
||||||
|
tensorizer_args.tensorizer_uri = args.path_to_tensors
|
||||||
|
else:
|
||||||
|
tensorizer_args = None
|
||||||
|
|
||||||
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
if args.command == "serialize":
|
||||||
os.environ["MASTER_PORT"] = "8080"
|
eng_args_dict = {f.name: getattr(args, f.name) for f in
|
||||||
|
dataclasses.fields(EngineArgs)}
|
||||||
|
|
||||||
init_distributed_environment(world_size=1, rank=0, local_rank=0)
|
engine_args = EngineArgs.from_cli_args(
|
||||||
initialize_model_parallel()
|
argparse.Namespace(**eng_args_dict)
|
||||||
|
|
||||||
keyfile = args.keyfile if args.keyfile else None
|
|
||||||
|
|
||||||
|
|
||||||
if args.model_loader_extra_config:
|
|
||||||
config = json.loads(args.model_loader_extra_config)
|
|
||||||
tensorizer_args = TensorizerConfig(**config)._construct_tensorizer_args()
|
|
||||||
tensorizer_args.tensorizer_uri = args.path_to_tensors
|
|
||||||
else:
|
|
||||||
tensorizer_args = None
|
|
||||||
|
|
||||||
if args.command == "serialize":
|
|
||||||
eng_args_dict = {f.name: getattr(args, f.name) for f in
|
|
||||||
dataclasses.fields(EngineArgs)}
|
|
||||||
|
|
||||||
engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
|
|
||||||
engine = LLMEngine.from_engine_args(engine_args)
|
|
||||||
|
|
||||||
input_dir = args.serialized_directory.rstrip('/')
|
|
||||||
suffix = args.suffix if args.suffix else uuid.uuid4().hex
|
|
||||||
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
|
|
||||||
model_path = f"{base_path}/model.tensors"
|
|
||||||
tensorizer_config = TensorizerConfig(
|
|
||||||
tensorizer_uri=model_path,
|
|
||||||
**credentials)
|
|
||||||
serialize_vllm_model(engine, tensorizer_config, keyfile)
|
|
||||||
elif args.command == "deserialize":
|
|
||||||
if not tensorizer_args:
|
|
||||||
tensorizer_config = TensorizerConfig(
|
|
||||||
tensorizer_uri=args.path_to_tensors,
|
|
||||||
encryption_keyfile = keyfile,
|
|
||||||
**credentials
|
|
||||||
)
|
)
|
||||||
deserialize()
|
|
||||||
else:
|
input_dir = args.serialized_directory.rstrip('/')
|
||||||
raise ValueError("Either serialize or deserialize must be specified.")
|
suffix = args.suffix if args.suffix else uuid.uuid4().hex
|
||||||
|
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
|
||||||
|
if engine_args.tensor_parallel_size > 1:
|
||||||
|
model_path = f"{base_path}/model-rank-%03d.tensors"
|
||||||
|
else:
|
||||||
|
model_path = f"{base_path}/model.tensors"
|
||||||
|
|
||||||
|
tensorizer_config = TensorizerConfig(
|
||||||
|
tensorizer_uri=model_path,
|
||||||
|
encryption_keyfile=keyfile,
|
||||||
|
**credentials)
|
||||||
|
|
||||||
|
tensorize_vllm_model(engine_args, tensorizer_config)
|
||||||
|
|
||||||
|
elif args.command == "deserialize":
|
||||||
|
if not tensorizer_args:
|
||||||
|
tensorizer_config = TensorizerConfig(
|
||||||
|
tensorizer_uri=args.path_to_tensors,
|
||||||
|
encryption_keyfile = keyfile,
|
||||||
|
**credentials
|
||||||
|
)
|
||||||
|
deserialize()
|
||||||
|
else:
|
||||||
|
raise ValueError("Either serialize or deserialize must be specified.")
|
||||||
|
|||||||
15
format.sh
15
format.sh
@@ -36,12 +36,12 @@ tool_version_check() {
|
|||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)"
|
tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-lint.txt | cut -d'=' -f3)"
|
||||||
tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | cut -d'=' -f3)"
|
tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-lint.txt | cut -d'=' -f3)"
|
||||||
tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
|
tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-lint.txt | cut -d'=' -f3)"
|
||||||
tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-dev.txt | cut -d'=' -f3)"
|
tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-lint.txt | cut -d'=' -f3)"
|
||||||
tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)"
|
tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-lint.txt | cut -d'=' -f3)"
|
||||||
tool_version_check "clang-format" "$CLANGFORMAT_VERSION" "$(grep clang-format requirements-dev.txt | cut -d'=' -f3)"
|
tool_version_check "clang-format" "$CLANGFORMAT_VERSION" "$(grep clang-format requirements-lint.txt | cut -d'=' -f3)"
|
||||||
|
|
||||||
YAPF_FLAGS=(
|
YAPF_FLAGS=(
|
||||||
'--recursive'
|
'--recursive'
|
||||||
@@ -101,6 +101,7 @@ mypy vllm/core --config-file pyproject.toml
|
|||||||
mypy vllm/distributed --config-file pyproject.toml
|
mypy vllm/distributed --config-file pyproject.toml
|
||||||
mypy vllm/entrypoints --config-file pyproject.toml
|
mypy vllm/entrypoints --config-file pyproject.toml
|
||||||
mypy vllm/executor --config-file pyproject.toml
|
mypy vllm/executor --config-file pyproject.toml
|
||||||
|
mypy vllm/multimodal --config-file pyproject.toml
|
||||||
mypy vllm/usage --config-file pyproject.toml
|
mypy vllm/usage --config-file pyproject.toml
|
||||||
mypy vllm/*.py --config-file pyproject.toml
|
mypy vllm/*.py --config-file pyproject.toml
|
||||||
mypy vllm/transformers_utils --config-file pyproject.toml
|
mypy vllm/transformers_utils --config-file pyproject.toml
|
||||||
@@ -117,7 +118,7 @@ mypy vllm/model_executor --config-file pyproject.toml
|
|||||||
# https://github.com/codespell-project/codespell/issues/1915
|
# https://github.com/codespell-project/codespell/issues/1915
|
||||||
# Avoiding the "./" prefix and using "/**" globs for directories appears to solve the problem
|
# Avoiding the "./" prefix and using "/**" globs for directories appears to solve the problem
|
||||||
CODESPELL_EXCLUDES=(
|
CODESPELL_EXCLUDES=(
|
||||||
'--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,tests/lora/data/**,build/**'
|
'--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**'
|
||||||
)
|
)
|
||||||
|
|
||||||
# check spelling of specified files
|
# check spelling of specified files
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ exclude = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[tool.codespell]
|
[tool.codespell]
|
||||||
ignore-words-list = "dout, te, indicies"
|
ignore-words-list = "dout, te, indicies, subtile"
|
||||||
skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
|
skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
|
||||||
|
|
||||||
[tool.isort]
|
[tool.isort]
|
||||||
@@ -71,4 +71,5 @@ markers = [
|
|||||||
"skip_global_cleanup",
|
"skip_global_cleanup",
|
||||||
"llm: run tests for vLLM API only",
|
"llm: run tests for vLLM API only",
|
||||||
"openai: run tests for OpenAI API only",
|
"openai: run tests for OpenAI API only",
|
||||||
|
"llava: run tests for LLaVA models only",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -12,10 +12,11 @@ aiohttp
|
|||||||
openai
|
openai
|
||||||
uvicorn[standard]
|
uvicorn[standard]
|
||||||
pydantic >= 2.0 # Required for OpenAI server.
|
pydantic >= 2.0 # Required for OpenAI server.
|
||||||
|
pillow # Required for image processing
|
||||||
prometheus_client >= 0.18.0
|
prometheus_client >= 0.18.0
|
||||||
prometheus-fastapi-instrumentator >= 7.0.0
|
prometheus-fastapi-instrumentator >= 7.0.0
|
||||||
tiktoken >= 0.6.0 # Required for DBRX tokenizer
|
tiktoken >= 0.6.0 # Required for DBRX tokenizer
|
||||||
lm-format-enforcer == 0.10.1
|
lm-format-enforcer == 0.10.1
|
||||||
outlines == 0.0.34 # Requires torch >= 2.1.0
|
outlines >= 0.0.43 # Requires torch >= 2.1.0
|
||||||
typing_extensions
|
typing_extensions
|
||||||
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
|
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user