Compare commits
v0.5.2...v0.5.3.pos (103 commits)
| SHA1 |
|---|
| 38c4b7e863 |
| a112a84aad |
| 461089a21a |
| 71950af726 |
| cb1362a889 |
| bb2fc08072 |
| 3eda4ec780 |
| 22fa2e35cb |
| c5201240a4 |
| 97234be0ec |
| c051bfe4eb |
| 9e0b558a09 |
| e519ae097a |
| 7c2749a4fd |
| 729171ae58 |
| c5e8330997 |
| e0c15758b8 |
| bdf5fd1386 |
| 5a96ee52a3 |
| 42c7f66a38 |
| 69d5ae38dc |
| fea59c7712 |
| 739b61a348 |
| 89c1c6a196 |
| 42de2cefcb |
| c9eef37f32 |
| 396d92d5e0 |
| 25e778aa16 |
| b6df37f943 |
| 14f91fe67c |
| d7f4178dd9 |
| 082ecd80d5 |
| f952bbc8ff |
| 9364f74eee |
| 06d6c5fe9f |
| 683e3cb9c4 |
| 9042d68362 |
| 3f8d42c81f |
| 7bd82002ae |
| 2e26564259 |
| e81522e879 |
| 45ceb85a0c |
| 4cc24f01b1 |
| 07eb6f19f3 |
| f0bbfaf917 |
| 30efe41532 |
| 9ed82e7074 |
| 51f8aa90ad |
| a5314e8698 |
| a921e86392 |
| 6366efc67b |
| dbe5588554 |
| d4201e06d5 |
| b5672a112c |
| c5df56f88b |
| 1689219ebf |
| 4ffffccb7e |
| f53b8f0d05 |
| 2d4733ba2d |
| 15c6a079b1 |
| ecdb462c24 |
| 58ca663224 |
| 4634c8728b |
| c8a7d51c49 |
| e2fbaee725 |
| 8a74c68bd1 |
| 61e592747c |
| d25877dd9b |
| 1c27d25fb5 |
| 18fecc3559 |
| b5af8c223c |
| b5241e41d9 |
| e76466dde2 |
| 5f0b9933e6 |
| a38524f338 |
| 2fa4623d9e |
| a9a2e74d21 |
| e09ce759aa |
| 5fa6e9876e |
| 5bf35a91e4 |
| a19e8d3726 |
| 10383887e0 |
| 1d094fd7c0 |
| ce37be7ba0 |
| 7f62077af5 |
| 09c2eb85dd |
| 978aed5300 |
| 160e1d8c99 |
| 94162beb9f |
| c467dff24f |
| 9f4ccec761 |
| 38ef94888a |
| 2bb0489cb3 |
| 7508a3dc34 |
| 7a3d2a5b95 |
| d97011512e |
| 37d776606f |
| d92b3c5cde |
| 9ad32dacd9 |
| d6f3b3d5c4 |
| 4552e37b55 |
| ec9933f4a5 |
| 3dee97b05f |
```diff
@@ -1,14 +0,0 @@
-#!/bin/bash
-
-set -ex
-set -o pipefail
-
-(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
-
-# aws s3 sync s3://air-example-data-2/vllm_opensource_llava/ images/
-mkdir -p images
-cd images
-wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign.jpg
-wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom.jpg
-
-cd -
```
```diff
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5
+model_name: "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.905
+  - name: "exact_match,flexible-extract"
+    value: 0.905
+limit: 1000
+num_fewshot: 5
```
```diff
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors -b auto -l 1000 -f 5 -t 1
+model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.752
+  - name: "exact_match,flexible-extract"
+    value: 0.754
+limit: 1000
+num_fewshot: 5
```
```diff
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5 -t 1
+model_name: "nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.753
+  - name: "exact_match,flexible-extract"
+    value: 0.753
+limit: 1000
+num_fewshot: 5
```
```diff
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test -b auto -l 1000 -f 5 -t 1
+model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.758
+  - name: "exact_match,flexible-extract"
+    value: 0.759
+limit: 1000
+num_fewshot: 5
```
```diff
@@ -1,3 +1,4 @@
+Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml
 Meta-Llama-3-70B-Instruct.yaml
 Mixtral-8x7B-Instruct-v0.1.yaml
 Qwen2-57B-A14-Instruct.yaml
```
```diff
@@ -2,4 +2,6 @@ Meta-Llama-3-8B-Instruct.yaml
 Meta-Llama-3-8B-Instruct-FP8.yaml
 Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
+Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
+Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
```
```diff
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
 done
 
 lm_eval --model vllm \
-  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray",trust_remote_code=true,max_model_len=4096 \
+  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend="ray",trust_remote_code=true,max_model_len=4096 \
   --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
   --batch_size $BATCH_SIZE
```
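The new YAML files above record the expected GSM8K accuracies that the harness compares against. As a rough, hypothetical illustration of how such a config could be consumed (the loader, tolerance, and measured values below are assumptions for the sketch, not code from this change; PyYAML is assumed to be installed):

```python
import yaml

RTOL = 0.02  # assumed tolerance; the real CI script may use a different value

# Path points at one of the config files added in this diff.
with open("Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml") as f:
    config = yaml.safe_load(f)

# Placeholder numbers standing in for results returned by lm_eval.
measured = {
    "exact_match,strict-match": 0.751,
    "exact_match,flexible-extract": 0.755,
}

for task in config["tasks"]:
    for metric in task["metrics"]:
        name, expected = metric["name"], metric["value"]
        got = measured[name]
        ok = abs(got - expected) <= RTOL
        print(f"{task['name']}/{name}: expected {expected}, got {got}, ok={ok}")
```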
```diff
@@ -3,13 +3,15 @@ steps:
     agents:
       queue: cpu_queue
     commands:
-      - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION={{matrix.cuda_version}} --tag vllm-ci:build-image --target build --progress plain ."
+      - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg buildkite_commit=$BUILDKITE_COMMIT --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION={{matrix.cuda_version}} --tag vllm-ci:build-image --target build --progress plain ."
       - "mkdir artifacts"
       - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
       # rename the files to change linux -> manylinux1
       - "for f in artifacts/dist/*.whl; do mv -- \"$$f\" \"$${f/linux/manylinux1}\"; done"
       - "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/$BUILDKITE_COMMIT/"
       - "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/nightly/"
+    env:
+      DOCKER_BUILDKIT: "1"
     matrix:
       setup:
         cuda_version:
```
```diff
@@ -66,11 +66,18 @@ trap remove_docker_container EXIT
 
 echo "--- Running container"
 
+HF_CACHE="$(realpath ~)/huggingface"
+mkdir -p ${HF_CACHE}
+HF_MOUNT="/root/.cache/huggingface"
+
 docker run \
         --device /dev/kfd --device /dev/dri \
         --network host \
+        --shm-size=16gb \
         --rm \
         -e HF_TOKEN \
+        -v ${HF_CACHE}:${HF_MOUNT} \
+        -e HF_HOME=${HF_MOUNT} \
         --name ${container_name} \
         ${image_name} \
         /bin/bash -c "${@}"
```
.buildkite/run-tpu-test.sh (new file, 16 lines)
```diff
@@ -0,0 +1,16 @@
+set -e
+
+# Build the docker image.
+docker build -f Dockerfile.tpu -t vllm-tpu .
+
+# Set up cleanup.
+remove_docker_container() { docker rm -f tpu-test || true; }
+trap remove_docker_container EXIT
+# Remove the container that might not be cleaned up in the previous run.
+remove_docker_container
+
+# For HF_TOKEN.
+source /etc/environment
+# Run a simple end-to-end example.
+docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu \
+    python3 /workspace/vllm/examples/offline_inference_tpu.py
```
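The new TPU CI step runs `examples/offline_inference_tpu.py` inside the container. As a rough sketch only (the model name and prompts below are illustrative placeholders, not the actual contents of that example; vLLM must be installed and able to download the model), the kind of end-to-end offline inference it exercises looks like this:

```python
from vllm import LLM, SamplingParams

prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)

# Placeholder model purely for illustration.
llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(prompts, sampling_params)
for out in outputs:
    print(out.outputs[0].text)
```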
```diff
@@ -12,7 +12,6 @@ steps:
   fast_check_only: true
   commands:
   - pytest -v -s async_engine # Async Engine
-  - bash ../.buildkite/download-images.sh # Inputs
   - pytest -v -s test_inputs.py
   - pytest -v -s multimodal
   - pytest -v -s test_utils.py # Utils
```
```diff
@@ -22,7 +21,7 @@ steps:
   fast_check: true
   fast_check_only: true
   commands:
-  - apt-get install curl libsodium23 && pytest -v -s tensorizer_loader # Tensorizer
+  - apt-get install -y curl libsodium23 && pytest -v -s tensorizer_loader # Tensorizer
   - pytest -v -s metrics # Metrics
   - "pip install \
     opentelemetry-sdk \
```
```diff
@@ -45,8 +44,10 @@ steps:
   mirror_hardwares: [amd]
   fast_check: true
   commands:
-  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
+  # This flashinfer installation will fail on AMD ROCm, so it is set as optional.
+  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl || true
   - pytest -v -s basic_correctness/test_basic_correctness.py
+  - pytest -v -s basic_correctness/test_cpu_offload.py
   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
   - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
```
```diff
@@ -54,7 +55,7 @@ steps:
 - label: Core Test
   mirror_hardwares: [amd]
   fast_check: true
   commands:
   - pytest -v -s core
   - pytest -v -s distributed/test_parallel_state.py
 
```
```diff
@@ -73,7 +74,7 @@ steps:
   commands:
   - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
   - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
-  - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
+  - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
   - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
   - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
 
```
```diff
@@ -82,10 +83,11 @@ steps:
   working_dir: "/vllm-workspace/tests"
   num_gpus: 2
   commands:
-  - bash ../.buildkite/download-images.sh
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
+  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
+  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
```
```diff
@@ -110,6 +112,7 @@ steps:
   # We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here.
   # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context.
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
+  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
 
```
```diff
@@ -117,16 +120,11 @@ steps:
   working_dir: "/vllm-workspace/tests"
   num_gpus: 4
   commands:
-  - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
-  - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
-  - TP_SIZE=1 PP_SIZE=3 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
-  - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
-  - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
+  - pytest -v -s distributed/test_pipeline_parallel.py
-
 
 - label: Engine Test
   mirror_hardwares: [amd]
   commands:
   - pytest -v -s engine test_sequence.py test_config.py test_logger.py
   # OOM in the CI unless we run this separately
   - pytest -v -s tokenization
```
```diff
@@ -147,6 +145,7 @@ steps:
   # install tensorizer for tensorize_vllm_model.py
   - pip install awscli tensorizer
   - python3 offline_inference.py
+  - python3 cpu_offload.py
   - python3 offline_inference_with_prefix.py
   - python3 llm_engine_example.py
   - python3 llava_example.py
```
```diff
@@ -155,7 +154,6 @@ steps:
 - label: Inputs Test
   #mirror_hardwares: [amd]
   commands:
-  - bash ../.buildkite/download-images.sh
   - pytest -v -s test_inputs.py
   - pytest -v -s multimodal
 
```
```diff
@@ -175,7 +173,6 @@ steps:
 - label: Vision Language Models Test
   mirror_hardwares: [amd]
   commands:
-  - bash ../.buildkite/download-images.sh
   - pytest -v -s models -m vlm
 
 - label: Prefix Caching Test
```
```diff
@@ -225,7 +222,7 @@ steps:
 - label: Tensorizer Test
   #mirror_hardwares: [amd]
   commands:
-  - apt-get install curl libsodium23
+  - apt-get install -y curl libsodium23
   - export VLLM_WORKER_MULTIPROC_METHOD=spawn
   - pytest -v -s tensorizer_loader
 
```
.github/workflows/reminder_comment.yml (vendored, 2 changed lines)
```diff
@@ -15,7 +15,7 @@ jobs:
             owner: context.repo.owner,
             repo: context.repo.repo,
             issue_number: context.issue.number,
-            body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only trigger `fastcheck` CI to run, which consists only a small and essential subset of tests to quickly catch errors with the flexibility to run extra individual tests on top (you can do this by unblocking test steps in the Buildkite run). \n\nFull CI run is still required to merge this PR so once the PR is ready to go, please make sure to run it. If you need all test signals in between PR commits, you can trigger full CI as well.\n\n To run full CI, you can do one of these:\n- Comment `/ready` on the PR\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀'
+            body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your `fast-check` build on Buildkite UI. \n\nOnce the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).\n\n To run full CI, you can do one of these:\n- Comment `/ready` on the PR\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀'
             })
   env:
     GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
```
```diff
@@ -33,7 +33,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
 # versions are derived from Dockerfile.rocm
 #
 set(TORCH_SUPPORTED_VERSION_CUDA "2.3.1")
-set(TORCH_SUPPORTED_VERSION_ROCM "2.4.0")
+set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")
 
 #
 # Try to find python package with an executable that exactly matches
```
```diff
@@ -101,7 +101,7 @@ elseif(HIP_FOUND)
   # ROCm 5.X and 6.X
   if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
       NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
-    message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM} "
+    message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
       "expected for ROCm build, saw ${Torch_VERSION} instead.")
   endif()
 else()
```
```diff
@@ -151,6 +151,7 @@ set(VLLM_EXT_SRC
   "csrc/quantization/fp8/common.cu"
   "csrc/cuda_utils_kernels.cu"
   "csrc/moe_align_block_size_kernels.cu"
+  "csrc/prepare_inputs/advance_step.cu"
   "csrc/torch_bindings.cpp")
 
 if(VLLM_GPU_LANG STREQUAL "CUDA")
```
```diff
@@ -171,6 +172,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
+    "csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
     "csrc/quantization/fp8/fp8_marlin.cu"
     "csrc/custom_all_reduce.cu"
     "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
```
Dockerfile (42 changed lines)
```diff
@@ -8,10 +8,10 @@
 ARG CUDA_VERSION=12.4.1
 #################### BASE BUILD IMAGE ####################
 # prepare basic build environment
-FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS base
+FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
 
 ARG CUDA_VERSION=12.4.1
-ARG PYTHON_VERSION=3
+ARG PYTHON_VERSION=3.10
 
 ENV DEBIAN_FRONTEND=noninteractive
 
```
```diff
@@ -21,13 +21,16 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
     && apt-get install -y ccache software-properties-common \
     && add-apt-repository ppa:deadsnakes/ppa \
    && apt-get update -y \
-    && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv python3-pip \
+    && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
     && if [ "${PYTHON_VERSION}" != "3" ]; then update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1; fi \
-    && python3 --version \
-    && python3 -m pip --version
+    && python3 --version
 
 RUN apt-get update -y \
-    && apt-get install -y python3-pip git curl sudo
+    && apt-get install -y git curl sudo
+
+# Install pip s.t. it will be compatible with our PYTHON_VERSION
+RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION}
+RUN python3 -m pip --version
 
 # Workaround for https://github.com/openai/triton/issues/2507 and
 # https://github.com/pytorch/pytorch/issues/107960 -- hopefully
```
```diff
@@ -58,7 +61,7 @@ ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
 #################### WHEEL BUILD IMAGE ####################
 FROM base AS build
 
-ARG PYTHON_VERSION=3
+ARG PYTHON_VERSION=3.10
 
 # install build dependencies
 COPY requirements-build.txt requirements-build.txt
```
```diff
@@ -100,7 +103,11 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     && tar -xzf sccache.tar.gz \
     && sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
     && rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
-    && export SCCACHE_BUCKET=vllm-build-sccache \
+    && if [ "$CUDA_VERSION" = "11.8.0" ]; then \
+        export SCCACHE_BUCKET=vllm-build-sccache-2; \
+       else \
+        export SCCACHE_BUCKET=vllm-build-sccache; \
+       fi \
     && export SCCACHE_REGION=us-west-2 \
     && export CMAKE_BUILD_TYPE=Release \
     && sccache --show-stats \
```
```diff
@@ -149,12 +156,27 @@ RUN pip --verbose wheel -r requirements-mamba.txt \
 
 #################### vLLM installation IMAGE ####################
 # image with vLLM installed
-FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS vllm-base
+FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base
 ARG CUDA_VERSION=12.4.1
+ARG PYTHON_VERSION=3.10
 WORKDIR /vllm-workspace
 
+RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
+    && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
+    && apt-get update -y \
+    && apt-get install -y ccache software-properties-common \
+    && add-apt-repository ppa:deadsnakes/ppa \
+    && apt-get update -y \
+    && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
+    && if [ "${PYTHON_VERSION}" != "3" ]; then update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1; fi \
+    && python3 --version
+
 RUN apt-get update -y \
-    && apt-get install -y python3-pip git vim
+    && apt-get install -y python3-pip git vim curl libibverbs-dev
 
+# Install pip s.t. it will be compatible with our PYTHON_VERSION
+RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION}
+RUN python3 -m pip --version
+
 # Workaround for https://github.com/openai/triton/issues/2507 and
 # https://github.com/pytorch/pytorch/issues/107960 -- hopefully
```
```diff
@@ -1,7 +1,7 @@
 # The vLLM Dockerfile is used to construct vLLM image that can be directly used
 # to run the OpenAI compatible server.
 
-FROM ubuntu:22.04 AS dev
+FROM ubuntu:20.04 AS dev
 
 RUN apt-get update -y && \
     apt-get install -y python3-pip git
```
```diff
@@ -1,26 +1,24 @@
 # Default ROCm 6.1 base image
 ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
 
-# Tested and supported base rocm/pytorch images
-ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu20.04_py3.9_pytorch_2.0.1" \
-    ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" \
-    ROCM_6_1_BASE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
-
 # Default ROCm ARCHes to build vLLM for.
 ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"
 
-# Whether to build CK-based flash-attention
-# If 0, will not build flash attention
-# This is useful for gfx target where flash-attention is not supported
-# (i.e. those that do not appear in `FA_GFX_ARCHS`)
-# Triton FA is used by default on ROCm now so this is unnecessary.
+# Whether to install CK-based flash-attention
+# If 0, will not install flash-attention
 ARG BUILD_FA="1"
+# If `TRY_FA_WHEEL=1`, we will try installing flash-attention from `FA_WHEEL_URL`
+# If this succeeds, we use the downloaded wheel and skip building flash-attention.
+# Otherwise, ROCm flash-attention from `FA_BRANCH` will be built for the
+# architectures specified in `FA_GFX_ARCHS`
+ARG TRY_FA_WHEEL="1"
+ARG FA_WHEEL_URL="https://github.com/ROCm/flash-attention/releases/download/v2.5.9post1-cktile-vllm/flash_attn-2.5.9.post1-cp39-cp39-linux_x86_64.whl"
 ARG FA_GFX_ARCHS="gfx90a;gfx942"
-ARG FA_BRANCH="ae7928c"
+ARG FA_BRANCH="23a2b1c2"
 
 # Whether to build triton on rocm
 ARG BUILD_TRITON="1"
-ARG TRITON_BRANCH="0ef1848"
+ARG TRITON_BRANCH="e0fc12c"
 
 ### Base image build stage
 FROM $BASE_IMAGE AS base
```
```diff
@@ -48,27 +46,15 @@ RUN apt-get update && apt-get install -y \
 ARG APP_MOUNT=/vllm-workspace
 WORKDIR ${APP_MOUNT}
 
-RUN pip install --upgrade pip
+RUN python3 -m pip install --upgrade pip
 # Remove sccache so it doesn't interfere with ccache
 # TODO: implement sccache support across components
-RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)"
+RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
 # Install torch == 2.5.0 on ROCm
 RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
-        *"rocm-5.7"*) \
-        pip uninstall -y torch torchaudio torchvision \
-        && pip install --no-cache-dir --pre \
-            torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
-            torchvision==0.20.0.dev20240710 \
-           --index-url https://download.pytorch.org/whl/nightly/rocm5.7;; \
-        *"rocm-6.0"*) \
-        pip uninstall -y torch torchaudio torchvision \
-        && pip install --no-cache-dir --pre \
-            torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
-            torchvision==0.20.0.dev20240710 \
-           --index-url https://download.pytorch.org/whl/nightly/rocm6.0;; \
         *"rocm-6.1"*) \
-        pip uninstall -y torch torchaudio torchvision \
-        && pip install --no-cache-dir --pre \
+        python3 -m pip uninstall -y torch torchaudio torchvision \
+        && python3 -m pip install --no-cache-dir --pre \
             torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
             torchvision==0.20.0.dev20240710 \
            --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
```
```diff
@@ -87,29 +73,31 @@ ENV CCACHE_DIR=/root/.cache/ccache
 FROM base AS build_amdsmi
 # Build amdsmi wheel always
 RUN cd /opt/rocm/share/amd_smi \
-    && pip wheel . --wheel-dir=/install
+    && python3 -m pip wheel . --wheel-dir=/install
 
 
 ### Flash-Attention wheel build stage
 FROM base AS build_fa
 ARG BUILD_FA
+ARG TRY_FA_WHEEL
+ARG FA_WHEEL_URL
 ARG FA_GFX_ARCHS
 ARG FA_BRANCH
 # Build ROCm flash-attention wheel if `BUILD_FA = 1`
 RUN --mount=type=cache,target=${CCACHE_DIR} \
     if [ "$BUILD_FA" = "1" ]; then \
-       mkdir -p libs \
-       && cd libs \
-       && git clone https://github.com/ROCm/flash-attention.git \
-       && cd flash-attention \
-       && git checkout "${FA_BRANCH}" \
-       && git submodule update --init \
-       && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
-          *"rocm-5.7"*) \
-             export VLLM_TORCH_PATH="$(python3 -c 'import torch; print(torch.__path__[0])')" \
-             && patch "${VLLM_TORCH_PATH}"/utils/hipify/hipify_python.py hipify_patch.patch;; \
-          *) ;; esac \
-       && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
+       if [ "${TRY_FA_WHEEL}" = "1" ] && python3 -m pip install "${FA_WHEEL_URL}"; then \
+       # If a suitable wheel exists, we download it instead of building FA
+          mkdir -p /install && wget -N "${FA_WHEEL_URL}" -P /install; \
+       else \
+          mkdir -p libs \
+          && cd libs \
+          && git clone https://github.com/ROCm/flash-attention.git \
+          && cd flash-attention \
+          && git checkout "${FA_BRANCH}" \
+          && git submodule update --init \
+          && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
+       fi; \
    # Create an empty directory otherwise as later build stages expect one
    else mkdir -p /install; \
    fi
```
```diff
@@ -148,7 +136,7 @@ RUN case "$(which python3)" in \
 
 # Package upgrades for useful functionality or to avoid dependency issues
 RUN --mount=type=cache,target=/root/.cache/pip \
-    pip install --upgrade numba scipy huggingface-hub[cli]
+    python3 -m pip install --upgrade numba scipy huggingface-hub[cli]
 
 # Make sure punica kernels are built (for LoRA)
 ENV VLLM_INSTALL_PUNICA_KERNELS=1
```
```diff
@@ -159,14 +147,11 @@ ENV TOKENIZERS_PARALLELISM=false
 
 RUN --mount=type=cache,target=${CCACHE_DIR} \
     --mount=type=cache,target=/root/.cache/pip \
-    pip install -U -r requirements-rocm.txt \
+    python3 -m pip install -Ur requirements-rocm.txt \
     && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
-        *"rocm-6.0"*) \
-            patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch;; \
         *"rocm-6.1"*) \
            # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM
-           wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P rocm_patch \
-           && cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6 \
+           wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P /opt/rocm/lib \
            # Prevent interference if torch bundles its own HIP runtime
            && rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \
         *) ;; esac \
```
```diff
@@ -178,7 +163,7 @@ RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
     mkdir -p libs \
     && cp /install/*.whl libs \
     # Preemptively uninstall to avoid same-version no-installs
-    && pip uninstall -y amdsmi;
+    && python3 -m pip uninstall -y amdsmi;
 
 # Copy triton wheel(s) into final image if they were built
 RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
```
```diff
@@ -186,7 +171,7 @@ RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
     && if ls /install/*.whl; then \
        cp /install/*.whl libs \
        # Preemptively uninstall to avoid same-version no-installs
-       && pip uninstall -y triton; fi
+       && python3 -m pip uninstall -y triton; fi
 
 # Copy flash-attn wheel(s) into final image if they were built
 RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
```
```diff
@@ -194,11 +179,11 @@ RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
     && if ls /install/*.whl; then \
        cp /install/*.whl libs \
        # Preemptively uninstall to avoid same-version no-installs
-       && pip uninstall -y flash-attn; fi
+       && python3 -m pip uninstall -y flash-attn; fi
 
 # Install wheels that were built to the final image
 RUN --mount=type=cache,target=/root/.cache/pip \
     if ls libs/*.whl; then \
-    pip install libs/*.whl; fi
+    python3 -m pip install libs/*.whl; fi
 
 CMD ["/bin/bash"]
```
```diff
@@ -1,4 +1,4 @@
-ARG NIGHTLY_DATE="20240601"
+ARG NIGHTLY_DATE="20240713"
 ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
 
 FROM $BASE_IMAGE
```
```diff
@@ -6,6 +6,8 @@ WORKDIR /workspace
 
 # Install aiohttp separately to avoid build errors.
 RUN pip install aiohttp
+# Install NumPy 1 instead of NumPy 2.
+RUN pip install "numpy<2"
 # Install the TPU and Pallas dependencies.
 RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
 RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
```
```diff
@@ -15,9 +17,4 @@ COPY . /workspace/vllm
 ENV VLLM_TARGET_DEVICE="tpu"
 RUN cd /workspace/vllm && python setup.py develop
 
-# Re-install outlines to avoid dependency errors.
-# The outlines version must follow requirements-common.txt.
-RUN pip uninstall outlines -y
-RUN pip install "outlines>=0.0.43"
-
 CMD ["/bin/bash"]
```
```diff
@@ -1,4 +1,4 @@
-FROM intel/oneapi-basekit:2024.1.0-devel-ubuntu22.04
+FROM intel/oneapi-basekit:2024.1.0-devel-ubuntu20.04
 
 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \
     echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \
```
README.md (11 changed lines)
```diff
@@ -16,7 +16,17 @@ Easy, fast, and cheap LLM serving for everyone
 
 ---
 
+**The Fifth vLLM Bay Area Meetup (July 24th 5pm-8pm PT)**
+
+We are excited to announce our fifth vLLM Meetup!
+Join us to hear the vLLM's recent updates and the upcoming roadmap.
+Additionally, our collaborators from AWS will be presenting their insights and experiences in deploying vLLM.
+Register now [here](https://lu.ma/lp0gyjqr) and be part of the event!
+
+---
+
 *Latest News* 🔥
+- [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html).
 - [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing).
 - [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
 - [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) with IBM! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
```
```diff
@@ -90,6 +100,7 @@ vLLM is a community project. Our compute resources for development and testing a
 - Databricks
 - DeepInfra
 - Dropbox
+- Google Cloud
 - Lambda Lab
 - NVIDIA
 - Replicate
```
```diff
@@ -11,7 +11,7 @@ from tqdm import tqdm
 
 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
-from vllm.inputs import PromptStrictInputs
+from vllm.inputs import PromptInputs
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import FlexibleArgumentParser
 
```
```diff
@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
     dummy_prompt_token_ids = np.random.randint(10000,
                                                size=(args.batch_size,
                                                      args.input_len))
-    dummy_inputs: List[PromptStrictInputs] = [{
+    dummy_inputs: List[PromptInputs] = [{
         "prompt_token_ids": batch
     } for batch in dummy_prompt_token_ids.tolist()]
 
```
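A minimal standalone sketch of the renamed type in use, mirroring the benchmark change above (assumptions: vLLM is installed at a version where the type is exported as `PromptInputs`; batch size and input length are arbitrary):

```python
from typing import List

import numpy as np

from vllm.inputs import PromptInputs

# Build dummy token-id prompts the same way the latency benchmark does.
dummy_prompt_token_ids = np.random.randint(10000, size=(8, 32))
dummy_inputs: List[PromptInputs] = [{
    "prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]
print(f"prepared {len(dummy_inputs)} dummy prompts")
```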
```diff
@@ -20,18 +20,18 @@ DEFAULT_TP_SIZES = [1]
 # helpers
 
 
-def to_fp8(tensor: torch.tensor) -> torch.tensor:
+def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
     finfo = torch.finfo(torch.float8_e4m3fn)
     return torch.round(tensor.clamp(
         min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
 
 
-def to_int8(tensor: torch.tensor) -> torch.tensor:
+def to_int8(tensor: torch.Tensor) -> torch.Tensor:
     return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
 
 
 def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
-                      k: int) -> Tuple[torch.tensor, torch.tensor]:
+                      k: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
     a = torch.randn((m, k), device='cuda') * 5
     b = torch.randn((n, k), device='cuda').t() * 5
```
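A quick usage sketch of the two helpers whose annotations were corrected above (assumptions: a PyTorch build with `float8_e4m3fn` support; CUDA is optional here, so the example falls back to CPU):

```python
import torch

def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
    # Clamp into the fp8 e4m3 range, round, then cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)

def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    # Clamp into int8 range, round, then cast.
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn((16, 32), device=device) * 5
print(to_fp8(a).dtype, to_int8(a).dtype)  # torch.float8_e4m3fn torch.int8
```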
```diff
@@ -47,15 +47,15 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
 # impl
 
 
-def pytorch_mm_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
-                    scale_b: torch.tensor,
-                    out_dtype: torch.dtype) -> torch.tensor:
+def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
+                    scale_b: torch.Tensor,
+                    out_dtype: torch.dtype) -> torch.Tensor:
     return torch.mm(a, b)
 
 
-def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
-                     scale_b: torch.tensor,
-                     out_dtype: torch.dtype) -> torch.tensor:
+def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
+                     scale_b: torch.Tensor,
+                     out_dtype: torch.dtype) -> torch.Tensor:
     return torch._scaled_mm(a,
                             b,
                             scale_a=scale_a,
```
```diff
@@ -63,9 +63,9 @@ def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
                             out_dtype=out_dtype)
 
 
-def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
-                                scale_a: torch.tensor, scale_b: torch.tensor,
-                                out_dtype: torch.dtype) -> torch.tensor:
+def pytorch_fp8_impl_fast_accum(a: torch.Tensor, b: torch.Tensor,
+                                scale_a: torch.Tensor, scale_b: torch.Tensor,
+                                out_dtype: torch.dtype) -> torch.Tensor:
     return torch._scaled_mm(a,
                             b,
                             scale_a=scale_a,
```
```diff
@@ -74,15 +74,15 @@ def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
                             use_fast_accum=True)
 
 
-def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
-                 scale_b: torch.tensor,
-                 out_dtype: torch.dtype) -> torch.tensor:
+def cutlass_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
+                 scale_b: torch.Tensor,
+                 out_dtype: torch.dtype) -> torch.Tensor:
     return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)
 
 
 # bench
-def bench_fn(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
-             scale_b: torch.tensor, out_dtype: torch.dtype, label: str,
+def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
+             scale_b: torch.Tensor, out_dtype: torch.dtype, label: str,
              sub_label: str, fn: Callable, description: str) -> TMeasurement:
 
     min_run_time = 1
```
```diff
@@ -100,7 +100,7 @@ def main(
     start_time = time.perf_counter()
 
     # Using default kv_scale
-    kv_scale = 1.0
+    k_scale = v_scale = 1.0
 
     for _ in range(num_iters):
         if version == "v1":
```
```diff
@@ -117,7 +117,8 @@ def main(
                 max_seq_len,
                 alibi_slopes,
                 kv_cache_dtype,
-                kv_scale,
+                k_scale,
+                v_scale,
             )
         elif version == "v2":
             ops.paged_attention_v2(
```
```diff
@@ -136,7 +137,8 @@ def main(
                 max_seq_len,
                 alibi_slopes,
                 kv_cache_dtype,
-                kv_scale,
+                k_scale,
+                v_scale,
             )
         else:
             raise ValueError(f"Invalid version: {version}")
```
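The benchmark and kernel hunks above replace the single shared `kv_scale` with separate `k_scale` and `v_scale` arguments. As a conceptual, standalone sketch of why that matters (plain PyTorch, not vLLM code; the int8 quantization and scale values are illustrative only): a single scale forces keys and values to share one dynamic range, while separate scales let each cache be dequantized independently.

```python
import torch

k_scale = 1.0  # matches the benchmark's default `k_scale = v_scale = 1.0`
v_scale = 1.0

k_quant = torch.randint(-128, 127, (4, 8), dtype=torch.int8)
v_quant = torch.randint(-128, 127, (4, 8), dtype=torch.int8)

# Analogous to fp8::scaled_convert(k_vec_quant, k_scale) in the kernel:
k = k_quant.to(torch.float32) * k_scale
# Analogous to fp8::scaled_convert(v_quant_vec, v_scale):
v = v_quant.to(torch.float32) * v_scale
print(k.shape, v.shape)
```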
```diff
@@ -105,9 +105,9 @@ __device__ void paged_attention_kernel(
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
-    const int blocksparse_vert_stride, const int blocksparse_block_size,
-    const int blocksparse_head_sliding_step) {
+    const float k_scale, const float v_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
   const int seq_idx = blockIdx.y;
   const int partition_idx = blockIdx.z;
   const int max_num_partitions = gridDim.z;
```
```diff
@@ -285,7 +285,7 @@ __device__ void paged_attention_kernel(
         Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
             k_ptr + offset1 * BLOCK_SIZE * x + offset2);
         k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
-            k_vec_quant, kv_scale);
+            k_vec_quant, k_scale);
       }
     }
 
```
```diff
@@ -415,7 +415,7 @@ __device__ void paged_attention_kernel(
           *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
       // Vector conversion from V_quant_vec to V_vec.
       v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
-                                                                kv_scale);
+                                                                v_scale);
     }
     if (block_idx == num_seq_blocks - 1) {
       // NOTE(woosuk): When v_vec contains the tokens that are out of the
```
```diff
@@ -513,15 +513,15 @@ __global__ void paged_attention_v1_kernel(
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
-    const int blocksparse_vert_stride, const int blocksparse_block_size,
-    const int blocksparse_head_sliding_step) {
+    const float k_scale, const float v_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                          KV_DTYPE, IS_BLOCK_SPARSE>(
       /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
       v_cache, num_kv_heads, scale, block_tables, seq_lens,
       max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
-      kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks,
+      kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
       blocksparse_vert_stride, blocksparse_block_size,
       blocksparse_head_sliding_step);
 }
```
@@ -549,14 +549,14 @@ __global__ void paged_attention_v2_kernel(
|
|||||||
const int max_num_blocks_per_seq,
|
const int max_num_blocks_per_seq,
|
||||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||||
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
|
const float k_scale, const float v_scale, const int tp_rank,
|
||||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||||
const int blocksparse_head_sliding_step) {
|
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
|
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
|
||||||
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
|
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
|
||||||
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
||||||
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
|
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
|
||||||
kv_block_stride, kv_head_stride, kv_scale, tp_rank,
|
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
|
||||||
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
|
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
|
||||||
blocksparse_head_sliding_step);
|
blocksparse_head_sliding_step);
|
||||||
}
|
}
|
||||||
@@ -682,7 +682,7 @@ __global__ void paged_attention_v2_reduce_kernel(
|
|||||||
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
|
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
|
||||||
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
|
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
|
||||||
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
|
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
|
||||||
kv_scale, tp_rank, blocksparse_local_blocks, \
|
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
|
||||||
blocksparse_vert_stride, blocksparse_block_size, \
|
blocksparse_vert_stride, blocksparse_block_size, \
|
||||||
blocksparse_head_sliding_step);
|
blocksparse_head_sliding_step);
|
||||||
|
|
||||||
@@ -694,8 +694,8 @@ void paged_attention_v1_launcher(
|
|||||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
|
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
|
||||||
const int tp_rank, const int blocksparse_local_blocks,
|
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
|
||||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
||||||
const int blocksparse_head_sliding_step) {
|
const int blocksparse_head_sliding_step) {
|
||||||
int num_seqs = query.size(0);
|
int num_seqs = query.size(0);
|
||||||
@@ -770,7 +770,7 @@ void paged_attention_v1_launcher(
|
|||||||
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
|
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
|
||||||
IS_BLOCK_SPARSE>( \
|
IS_BLOCK_SPARSE>( \
|
||||||
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
|
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
|
||||||
seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \
|
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
|
||||||
blocksparse_local_blocks, blocksparse_vert_stride, \
|
blocksparse_local_blocks, blocksparse_vert_stride, \
|
||||||
blocksparse_block_size, blocksparse_head_sliding_step);
|
blocksparse_block_size, blocksparse_head_sliding_step);
|
||||||
|
|
||||||
@@ -815,8 +815,8 @@ void paged_attention_v1(
|
|||||||
torch::Tensor& seq_lens, // [num_seqs]
|
torch::Tensor& seq_lens, // [num_seqs]
|
||||||
int64_t block_size, int64_t max_seq_len,
|
int64_t block_size, int64_t max_seq_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||||
const int64_t blocksparse_local_blocks,
|
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||||
const int64_t blocksparse_head_sliding_step) {
|
const int64_t blocksparse_head_sliding_step) {
|
||||||
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
||||||
@@ -833,7 +833,7 @@ void paged_attention_v1(
|
|||||||
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
|
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
|
||||||
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
|
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
|
||||||
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
|
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
|
||||||
kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
|
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
|
||||||
blocksparse_local_blocks, blocksparse_vert_stride, \
|
blocksparse_local_blocks, blocksparse_vert_stride, \
|
||||||
blocksparse_block_size, blocksparse_head_sliding_step); \
|
blocksparse_block_size, blocksparse_head_sliding_step); \
|
||||||
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
|
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
|
||||||
@@ -850,8 +850,8 @@ void paged_attention_v2_launcher(
|
|||||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
|
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
|
||||||
const int tp_rank, const int blocksparse_local_blocks,
|
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
|
||||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
||||||
const int blocksparse_head_sliding_step) {
|
const int blocksparse_head_sliding_step) {
|
||||||
int num_seqs = query.size(0);
|
int num_seqs = query.size(0);
|
||||||
@@ -932,8 +932,9 @@ void paged_attention_v2_launcher(
|
|||||||
IS_BLOCK_SPARSE>( \
|
IS_BLOCK_SPARSE>( \
|
||||||
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
||||||
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
|
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
|
||||||
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \
|
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
|
||||||
blocksparse_block_size, blocksparse_head_sliding_step);
|
blocksparse_vert_stride, blocksparse_block_size, \
|
||||||
|
blocksparse_head_sliding_step);
|
||||||
|
|
||||||
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
|
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
|
||||||
switch (is_block_sparse) { \
|
switch (is_block_sparse) { \
|
||||||
@@ -980,8 +981,8 @@ void paged_attention_v2(
|
|||||||
torch::Tensor& seq_lens, // [num_seqs]
|
torch::Tensor& seq_lens, // [num_seqs]
|
||||||
int64_t block_size, int64_t max_seq_len,
|
int64_t block_size, int64_t max_seq_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||||
const int64_t blocksparse_local_blocks,
|
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||||
const int64_t blocksparse_head_sliding_step) {
|
const int64_t blocksparse_head_sliding_step) {
|
||||||
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,
-                       const std::string& kv_cache_dtype,
-                       const double kv_scale);
+                       const std::string& kv_cache_dtype, const double k_scale,
+                       const double v_scale);

 void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                              torch::Tensor& key_cache,
@@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel(
                               //   block_size]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int key_stride, const int value_stride, const int num_heads,
-    const int head_size, const int block_size, const int x,
-    const float kv_scale) {
+    const int head_size, const int block_size, const int x, const float k_scale,
+    const float v_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   if (slot_idx < 0) {
@@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel(
       value_cache[tgt_value_idx] = tgt_value;
     } else {
       key_cache[tgt_key_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
       value_cache[tgt_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
     }
   }
 }
@@ -248,7 +248,7 @@ __global__ void reshape_and_cache_flash_kernel(
       reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
       reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
       slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
-      num_heads, head_size, block_size, x, kv_scale);
+      num_heads, head_size, block_size, x, k_scale, v_scale);

 void reshape_and_cache(
     torch::Tensor& key,    // [num_tokens, num_heads, head_size]
@@ -258,7 +258,8 @@ void reshape_and_cache(
     torch::Tensor&
         value_cache,  // [num_blocks, num_heads, head_size, block_size]
     torch::Tensor& slot_mapping,  // [num_tokens]
-    const std::string& kv_cache_dtype, const double kv_scale) {
+    const std::string& kv_cache_dtype, const double k_scale,
+    const double v_scale) {
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
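Note: below is a minimal host-side sketch of the updated cache-write call above, assuming an FP8 KV cache and pre-computed per-tensor scales. The wrapper name, the scale values, and the "fp8" dtype string are illustrative assumptions, not part of this diff.

#include <torch/torch.h>
#include <string>

// Declaration as updated above: one scale for the key cache, one for the value cache.
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                       torch::Tensor& key_cache, torch::Tensor& value_cache,
                       torch::Tensor& slot_mapping,
                       const std::string& kv_cache_dtype, const double k_scale,
                       const double v_scale);

// Hypothetical caller: previously a single kv_scale was passed for both caches.
void write_kv_to_cache(torch::Tensor& key, torch::Tensor& value,
                       torch::Tensor& key_cache, torch::Tensor& value_cache,
                       torch::Tensor& slot_mapping) {
  const double k_scale = 0.023;  // illustrative key dequantization scale
  const double v_scale = 0.031;  // illustrative value dequantization scale
  reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                    "fp8", k_scale, v_scale);
}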
@@ -318,13 +319,13 @@ namespace vllm {
 template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
 __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
                                    Tout* __restrict__ dst_cache,
-                                   const float kv_scale,
+                                   const float scale,
                                    const int64_t block_stride) {
   const int64_t block_idx = blockIdx.x;
   for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
     int64_t idx = block_idx * block_stride + i;
     dst_cache[idx] =
-        fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
+        fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], scale);
   }
 }

@@ -333,11 +334,11 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
 #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
   vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
       reinterpret_cast<Tin*>(src_cache.data_ptr()), \
-      reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride);
+      reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, block_stride);

 // Only for testing.
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
-                 const double kv_scale, const std::string& kv_cache_dtype) {
+                 const double scale, const std::string& kv_cache_dtype) {
   torch::Device src_device = src_cache.device();
   torch::Device dst_device = dst_cache.device();
   TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
@@ -423,11 +423,11 @@ void paged_attention_v1(
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
     int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
-    const int64_t blocksparse_local_blocks,
+    const std::string& kv_cache_dtype, double k_scale, double v_scale,
+    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step) {
-  TORCH_CHECK(kv_scale == 1.0f);
+  TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
   TORCH_CHECK(blocksparse_vert_stride <= 1,
               "CPU backend does not support blocksparse attention yet.");
   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
@@ -742,11 +742,11 @@ void paged_attention_v2(
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
     int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
-    const int64_t blocksparse_local_blocks,
+    const std::string& kv_cache_dtype, double k_scale, double v_scale,
+    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step) {
-  TORCH_CHECK(kv_scale == 1.0f);
+  TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
   TORCH_CHECK(blocksparse_vert_stride <= 1,
               "CPU backend does not support blocksparse attention yet.");
   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
@@ -107,8 +107,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,
-                       const std::string& kv_cache_dtype, double kv_scale) {
-  TORCH_CHECK(kv_scale == 1.0f);
+                       const std::string& kv_cache_dtype, double k_scale,
+                       double v_scale) {
+  TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);

   int num_tokens = key.size(0);
   int num_heads = key.size(1);
@@ -16,8 +16,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "  Tensor value_cache, int num_kv_heads, float scale,"
       "  Tensor block_tables, Tensor seq_lens, int block_size,"
       "  int max_seq_len, Tensor? alibi_slopes,"
-      "  str kv_cache_dtype, float kv_scale, int tp_rank,"
-      "  int blocksparse_local_blocks,"
+      "  str kv_cache_dtype, float k_scale, float v_scale,"
+      "  int tp_rank, int blocksparse_local_blocks,"
       "  int blocksparse_vert_stride, int blocksparse_block_size,"
       "  int blocksparse_head_sliding_step) -> ()");
   ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
@@ -30,8 +30,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "  Tensor value_cache, int num_kv_heads, float scale,"
       "  Tensor block_tables, Tensor seq_lens, int block_size,"
       "  int max_seq_len, Tensor? alibi_slopes,"
-      "  str kv_cache_dtype, float kv_scale, int tp_rank,"
-      "  int blocksparse_local_blocks,"
+      "  str kv_cache_dtype, float k_scale, float v_scale,"
+      "  int tp_rank, int blocksparse_local_blocks,"
       "  int blocksparse_vert_stride, int blocksparse_block_size,"
       "  int blocksparse_head_sliding_step) -> ()");
   ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);
@@ -103,7 +103,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "  Tensor! key_cache, Tensor! value_cache,"
       "  Tensor slot_mapping,"
       "  str kv_cache_dtype,"
-      "  float kv_scale) -> ()");
+      "  float k_scale, float v_scale) -> ()");
   cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
 }

csrc/ops.h (35 changes)
@@ -8,8 +8,8 @@ void paged_attention_v1(
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
     int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
-    const int64_t blocksparse_local_blocks,
+    const std::string& kv_cache_dtype, double k_scale, double v_scale,
+    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step);

@@ -19,8 +19,8 @@ void paged_attention_v2(
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
     int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
-    const int64_t blocksparse_local_blocks,
+    const std::string& kv_cache_dtype, double k_scale, double v_scale,
+    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step);

@@ -52,6 +52,11 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);

 void gelu_quick(torch::Tensor& out, torch::Tensor& input);

+void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
+                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
+                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
+                  torch::Tensor& slot_mapping, torch::Tensor& block_tables);
+
 #ifndef USE_ROCM
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,
@@ -84,15 +89,19 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                   int64_t size_k);

 torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
-                               torch::Tensor& b_scales, torch::Tensor& g_idx,
-                               torch::Tensor& perm, torch::Tensor& workspace,
-                               int64_t num_bits, int64_t size_m, int64_t size_n,
-                               int64_t size_k, bool is_k_full);
+                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
+                               torch::Tensor& g_idx, torch::Tensor& perm,
+                               torch::Tensor& workspace, int64_t num_bits,
+                               int64_t size_m, int64_t size_n, int64_t size_k,
+                               bool is_k_full, bool has_zp);

 torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
                                  int64_t num_bits);

+torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
+                                int64_t size_n, int64_t num_bits);
+
 torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
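Note: below is a minimal caller sketch against the reordered gptq_marlin_gemm declaration above. The wrapper name, the empty g_idx/perm placeholders, and the 4-bit setting are illustrative assumptions, not taken from this diff.

#include <torch/torch.h>

// Declaration as introduced above (zero-point tensor and has_zp flag are new).
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
                               torch::Tensor& g_idx, torch::Tensor& perm,
                               torch::Tensor& workspace, int64_t num_bits,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               bool is_k_full, bool has_zp);

// Hypothetical call for an AWQ-style 4-bit weight with zero points and no
// activation reordering, so g_idx and perm are passed as empty int tensors.
torch::Tensor marlin_gemm_with_zp(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_scales, torch::Tensor& b_zeros,
                                  torch::Tensor& workspace, int64_t m, int64_t n,
                                  int64_t k) {
  auto empty_ids = torch::empty({0}, a.options().dtype(torch::kInt));
  return gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros, empty_ids, empty_ids,
                          workspace, /*num_bits=*/4, m, n, k,
                          /*is_k_full=*/true, /*has_zp=*/true);
}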
@@ -123,12 +132,16 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,

 void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);

-void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
-                             torch::Tensor& scale);
+void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
+                             torch::Tensor const& scale);

-void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
+void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
                               torch::Tensor& scale);

+void dynamic_per_token_scaled_fp8_quant(
+    torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
+    c10::optional<torch::Tensor> const& scale_ub);
+
 void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           int64_t block_size, torch::Tensor sorted_token_ids,
                           torch::Tensor experts_ids,
csrc/prepare_inputs/advance_step.cu (new file, 131 lines)
@@ -0,0 +1,131 @@
+/*
+ * The goal of this GPU kernel is to advance input tensors on the GPU directly
+ * PR: https://github.com/vllm-project/vllm/pull/6338
+ * Current restrictions:
+ *     1. Specialized for DraftModelRunner
+ *     2. Supports flash_attn only
+ */
+
+#include "advance_step.cuh"
+
+namespace prepare_inputs {
+
+//
+template <int const num_threads>
+__global__ void advance_step_kernel(int num_seqs, int num_queries,
+                                    int block_size, long* input_tokens_ptr,
+                                    long const* sampled_token_ids_ptr,
+                                    long* input_positions_ptr,
+                                    int* seq_lens_ptr, long* slot_mapping_ptr,
+                                    int const* block_tables_ptr,
+                                    int64_t const block_tables_stride) {
+  int num_query_blocks = div_ceil(num_queries, num_threads);
+
+  if (blockIdx.x >= num_query_blocks) {
+    return;
+  }
+
+  int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
+
+  if (cur_query_id >= num_queries) {
+    return;
+  }
+
+  // Update input_tokens
+  input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
+
+  int seq_len = seq_lens_ptr[cur_query_id];
+  int next_seq_len = seq_len + 1;
+  int next_input_pos = next_seq_len - 1;
+
+  // Update seq_lens
+  seq_lens_ptr[cur_query_id] = next_seq_len;
+  // Update input_positions
+  input_positions_ptr[cur_query_id] = next_input_pos;
+
+  int const* seq_block_tables_ptr =
+      block_tables_ptr + block_tables_stride * cur_query_id;
+
+  int block_index = next_input_pos / block_size;
+  int block_offset = next_input_pos % block_size;
+
+  int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
+  // Update slot_mapping
+  slot_mapping_ptr[cur_query_id] = slot_num;
+}
+
+inline void verify_tensor(std::string const& name, torch::Tensor& t,
+                          int64_t const size_0, int64_t const size_1,
+                          c10::ScalarType const type) {
+  bool size_0_cond = true;
+  if (size_0 != -1) {
+    size_0_cond = t.size(0) == size_0;
+  }
+
+  bool size_1_cond = true;
+  if (size_1 != -1) {
+    size_1_cond = t.size(1) == size_1;
+  }
+
+  bool is_contiguous = t.is_contiguous();
+  bool same_type = t.dtype() == type;
+
+  bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
+  if (!pass) {
+    TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
+                " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
+                " is not as expected: shape = [", size_0, ", ", size_1,
+                "], type = ", type);
+  }
+}
+
+void advance_step(int num_seqs, int num_queries, int block_size,
+                  torch::Tensor& input_tokens,       // type: long
+                  torch::Tensor& sampled_token_ids,  // type: long
+                  torch::Tensor& input_positions,    // type: long
+                  torch::Tensor& seq_lens,           // type: int
+                  torch::Tensor& slot_mapping,       // type: long
+                  torch::Tensor& block_tables) {     // type: int
+
+  if (logging) {
+    printf("advance_step:\n");
+    printf("  num_seqs = %d\n", num_seqs);
+    printf("  num_queries = %d\n", num_queries);
+    printf("  block_size = %d\n", block_size);
+  }
+  // Verify all tensors
+  verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
+  verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
+                at::kLong);
+  verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
+  verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
+  verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
+  verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
+
+  int dev = sampled_token_ids.get_device();
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
+
+  int blocks;
+  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
+
+  advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
+      num_seqs, num_queries, block_size,
+      reinterpret_cast<long*>(input_tokens.data_ptr()),
+      reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
+      reinterpret_cast<long*>(input_positions.data_ptr()),
+      reinterpret_cast<int*>(seq_lens.data_ptr()),
+      reinterpret_cast<long*>(slot_mapping.data_ptr()),
+      reinterpret_cast<int const*>(block_tables.data_ptr()),
+      block_tables.stride(0));
+}
+
+}  // namespace prepare_inputs
+
+void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
+                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
+                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
+                  torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
+  prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
+                               sampled_token_ids, input_positions, seq_lens,
+                               slot_mapping, block_tables);
+}
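Note: below is a minimal host-side sketch of how the new advance_step op could be exercised. The shapes (two sequences, block_size 16) and tensor contents are illustrative placeholders, not taken from this diff.

#include <torch/torch.h>

// Declaration of the top-level wrapper defined above.
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
                  torch::Tensor& slot_mapping, torch::Tensor& block_tables);

void demo_advance_step() {
  const int64_t num_seqs = 2, num_queries = 2, block_size = 16;
  auto cuda_long = torch::TensorOptions().device(torch::kCUDA).dtype(torch::kLong);
  auto cuda_int = torch::TensorOptions().device(torch::kCUDA).dtype(torch::kInt);

  auto input_tokens = torch::zeros({num_seqs}, cuda_long);
  auto sampled_token_ids = torch::zeros({num_queries, 1}, cuda_long);
  auto input_positions = torch::zeros({num_seqs}, cuda_long);
  auto seq_lens = torch::full({num_seqs}, 8, cuda_int);
  auto slot_mapping = torch::zeros({num_seqs}, cuda_long);
  auto block_tables = torch::zeros({num_seqs, 4}, cuda_int);

  // Advances seq_lens by one, copies the sampled tokens into input_tokens, and
  // recomputes input_positions / slot_mapping for the next decode step.
  advance_step(num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
               input_positions, seq_lens, slot_mapping, block_tables);
}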
csrc/prepare_inputs/advance_step.cuh (new file, 19 lines)
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <torch/all.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <iostream>
+
+namespace prepare_inputs {
+
+static constexpr int max_threads = 256;
+static constexpr bool logging = false;
+
+constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
+
+}  // namespace prepare_inputs
@@ -7,6 +7,8 @@
 #include "cuda_compat.h"
 #include "dispatch_utils.h"

+#include "../../reduction_utils.cuh"
+
 namespace vllm {

 __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
@@ -21,10 +23,16 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {

 #define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()

-template <typename scalar_t>
+template <bool is_scale_inverted>
 __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
-    const scalar_t val, const float inverted_scale) {
-  float x = static_cast<float>(val) * inverted_scale;
+    float const val, float const scale) {
+  float x = 0.0f;
+  if constexpr (is_scale_inverted) {
+    x = val * scale;
+  } else {
+    x = val / scale;
+  }
+
   float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
   return static_cast<c10::Float8_e4m3fn>(r);
 }
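Note: the new is_scale_inverted flag above only selects between multiplying by a pre-inverted scale and dividing by the raw scale. The standalone host sketch below (plain float math, no FP8 cast, hypothetical helper name) just illustrates that contract.

#include <cassert>
#include <cmath>

// Host-side mirror of the branch added above, purely for illustration.
template <bool is_scale_inverted>
float scaled_value(float val, float scale) {
  return is_scale_inverted ? val * scale : val / scale;
}

int main() {
  const float val = 3.7f, scale = 0.25f;
  // Passing 1/scale with the "inverted" path matches dividing by the raw scale.
  assert(std::fabs(scaled_value<true>(val, 1.0f / scale) -
                   scaled_value<false>(val, scale)) < 1e-6f);
  return 0;
}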
@@ -87,6 +95,70 @@ typedef struct __align__(4) {
 }
 float8x4_t;

+template <typename scalar_t>
+__device__ float thread_max_vec(scalar_t const* __restrict__ input,
+                                int64_t const num_elems, int const tid,
+                                int const step) {
+  // Vectorized input/output to better utilize memory bandwidth.
+  vec4_t<scalar_t> const* vectorized_in =
+      reinterpret_cast<vec4_t<scalar_t> const*>(input);
+
+  int64_t const num_vec_elems = num_elems >> 2;
+  float absmax_val = 0.0f;
+
+#pragma unroll 4
+  for (int64_t i = tid; i < num_vec_elems; i += step) {
+    vec4_t<scalar_t> in_vec = vectorized_in[i];
+    absmax_val = max(absmax_val, fabs(in_vec.x));
+    absmax_val = max(absmax_val, fabs(in_vec.y));
+    absmax_val = max(absmax_val, fabs(in_vec.z));
+    absmax_val = max(absmax_val, fabs(in_vec.w));
+  }
+
+  // Handle the remaining elements if num_elems is not divisible by 4
+  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
+    absmax_val = max(absmax_val, fabs(input[i]));
+  }
+
+  return absmax_val;
+}
+
+template <typename scalar_t, bool is_scale_inverted>
+__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
+                                          scalar_t const* __restrict__ input,
+                                          float const scale,
+                                          int64_t const num_elems,
+                                          int const tid, int const step) {
+  // Vectorized input/output to better utilize memory bandwidth.
+  vec4_t<scalar_t> const* vectorized_in =
+      reinterpret_cast<vec4_t<scalar_t> const*>(input);
+  float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
+
+  int64_t const num_vec_elems = num_elems >> 2;
+
+#pragma unroll 4
+  for (int64_t i = tid; i < num_vec_elems; i += step) {
+    vec4_t<scalar_t> in_vec = vectorized_in[i];
+    float8x4_t out_vec;
+
+    out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
+        static_cast<float>(in_vec.x), scale);
+    out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
+        static_cast<float>(in_vec.y), scale);
+    out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
+        static_cast<float>(in_vec.z), scale);
+    out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
+        static_cast<float>(in_vec.w), scale);
+    vectorized_out[i] = out_vec;
+  }
+
+  // Handle the remaining elements if num_elems is not divisible by 4
+  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
+    out[i] = scaled_fp8_conversion<is_scale_inverted>(
+        static_cast<float>(input[i]), scale);
+  }
+}
+
 template <typename scalar_t>
 __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
                                         const scalar_t* __restrict__ input,
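Note: below is a CPU reference of the per-token scale that the kernel introduced in the next hunk computes from these helpers (absolute maximum over the token, then a clamped scale). The function name is illustrative; the 448 value is the FP8 E4M3 maximum that FP8_E4M3_MAX resolves to, and the /512 floor mirrors min_scaling_factor.

#include <algorithm>
#include <cmath>
#include <vector>

// Reference: scale = max(absmax(token) / FP8_MAX, 1 / (FP8_MAX * 512)).
float per_token_scale(const std::vector<float>& token, float fp8_max = 448.0f) {
  float absmax = 0.0f;
  for (float x : token) absmax = std::max(absmax, std::fabs(x));
  const float min_scaling_factor = 1.0f / (fp8_max * 512.0f);
  return std::max(absmax / fp8_max, min_scaling_factor);
}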
@@ -97,38 +169,68 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
   // Invert the scale so that we can use multiplications to avoid expensive
   // division.
   const float inverted_scale = 1.0f / (*scale);
-  // Vectorized input/output to better utilize memory bandwidth.
-  const vec4_t<scalar_t>* vectorized_in =
-      reinterpret_cast<const vec4_t<scalar_t>*>(input);
-  float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
-
-  int num_vec_elems = num_elems >> 2;
-
-#pragma unroll 4
-  for (int i = tid; i < num_vec_elems; i += blockDim.x * gridDim.x) {
-    vec4_t<scalar_t> in_vec = vectorized_in[i];
-    float8x4_t out_vec;
-
-    out_vec.x = scaled_fp8_conversion(in_vec.x, inverted_scale);
-    out_vec.y = scaled_fp8_conversion(in_vec.y, inverted_scale);
-    out_vec.z = scaled_fp8_conversion(in_vec.z, inverted_scale);
-    out_vec.w = scaled_fp8_conversion(in_vec.w, inverted_scale);
-    vectorized_out[i] = out_vec;
+  scaled_fp8_conversion_vec<scalar_t, true>(
+      out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
+}
+
+template <typename scalar_t>
+__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
+    c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
+    scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
+    const int hidden_size) {
+  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
+
+  int const tid = threadIdx.x;
+  int const token_idx = blockIdx.x;
+
+  scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
+  c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size];
+
+  // For vectorization, token_input and token_output pointers need to be
+  // aligned at 8-byte and 4-byte addresses respectively.
+  bool const can_vectorize = hidden_size % 4 == 0;
+
+  float absmax_val = 0.0f;
+  if (can_vectorize) {
+    absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
+  } else {
+    for (int i = tid; i < hidden_size; i += blockDim.x) {
+      float const x = static_cast<float>(token_input[i]);
+      absmax_val = max(absmax_val, fabs(x));
+    }
   }

-  // Handle the remaining elements if num_elems is not divisible by 4
-  for (int i = num_vec_elems * 4 + tid; i < num_elems;
-       i += blockDim.x * gridDim.x) {
-    out[i] = scaled_fp8_conversion(input[i], inverted_scale);
+  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
+  __shared__ float token_scale;
+  if (tid == 0) {
+    if (scale_ub) {
+      token_scale = min(block_absmax_val_maybe, *scale_ub);
+    } else {
+      token_scale = block_absmax_val_maybe;
+    }
+    // token scale computation
+    token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
+    scale[token_idx] = token_scale;
+  }
+  __syncthreads();
+
+  // Note that we don't use inverted scales so we can match FBGemm impl.
+  if (can_vectorize) {
+    scaled_fp8_conversion_vec<scalar_t, false>(
+        token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
+  } else {
+    for (int i = tid; i < hidden_size; i += blockDim.x) {
+      token_output[i] = scaled_fp8_conversion<false>(
+          static_cast<float>(token_input[i]), token_scale);
+    }
   }
 }

 }  // namespace vllm

 void static_scaled_fp8_quant(torch::Tensor& out,      // [..., d]
-                             torch::Tensor& input,    // [..., d]
-                             torch::Tensor& scale)    // [1]
+                             torch::Tensor const& input,  // [..., d]
+                             torch::Tensor const& scale)  // [1]
 {
   int64_t num_tokens = input.numel() / input.size(-1);
   int64_t num_elems = input.numel();
@@ -144,9 +246,9 @@ void static_scaled_fp8_quant(torch::Tensor& out,      // [..., d]
   });
 }

 void dynamic_scaled_fp8_quant(torch::Tensor& out,      // [..., d]
-                              torch::Tensor& input,    // [..., d]
+                              torch::Tensor const& input,  // [..., d]
                               torch::Tensor& scale)    // [1]
 {
   int64_t num_tokens = input.numel() / input.size(-1);
   int64_t num_elems = input.numel();
@@ -163,3 +265,28 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out,      // [..., d]
         scale.data_ptr<float>(), num_elems);
   });
 }
+
+void dynamic_per_token_scaled_fp8_quant(
+    torch::Tensor& out,          // [..., d]
+    torch::Tensor const& input,  // [..., d]
+    torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(out.is_contiguous());
+
+  int const hidden_size = input.size(-1);
+  int const num_tokens = input.numel() / hidden_size;
+  dim3 const grid(num_tokens);
+  dim3 const block(std::min(hidden_size, 1024));
+
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
+        vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
+            <<<grid, block, 0, stream>>>(
+                out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
+                input.data_ptr<scalar_t>(),
+                scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
+                hidden_size);
+      });
+}
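Note: below is a minimal caller sketch for the new dynamic_per_token_scaled_fp8_quant entry point, with illustrative shapes; how (or whether) it is surfaced through the Python custom-ops wrapper is not shown in this diff.

#include <torch/torch.h>
#include <optional>

// Declaration as added above.
void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales,
    std::optional<at::Tensor> const& scale_ub);

void demo_per_token_quant() {
  const int64_t num_tokens = 4, hidden_size = 128;
  auto input = torch::randn({num_tokens, hidden_size},
                            torch::dtype(torch::kHalf).device(torch::kCUDA));
  auto out = torch::empty_like(input,
                               input.options().dtype(at::ScalarType::Float8_e4m3fn));
  auto scales = torch::empty({num_tokens, 1},
                             input.options().dtype(torch::kFloat));
  // One scale per token; no upper bound on the scale in this example.
  dynamic_per_token_scaled_fp8_quant(out, input, scales, std::nullopt);
}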
@@ -19,10 +19,10 @@
  * Adapted from https://github.com/IST-DASLab/marlin
  */

-#include "../gptq_marlin/gptq_marlin.cuh"
-#include "../gptq_marlin/gptq_marlin_dtypes.cuh"
+#include "../gptq_marlin/marlin.cuh"
+#include "../gptq_marlin/marlin_dtypes.cuh"

-using namespace gptq_marlin;
+using namespace marlin;

 #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
   static_assert(std::is_same<scalar_t, half>::value || \
@@ -1224,16 +1224,15 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
               ", size_k = ", size_k);

   // Verify B
-  TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
-              " is not divisible by tile_size = ", gptq_marlin::tile_size);
-  TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
+  TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
+              " is not divisible by tile_size = ", marlin::tile_size);
+  TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
               "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
-              ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
-  TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
+              ", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
+  TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
               "b_q_weight.size(1) = ", b_q_weight.size(1),
-              " is not divisible by tile_size = ", gptq_marlin::tile_size);
-  int actual_size_n =
-      (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
+              " is not divisible by tile_size = ", marlin::tile_size);
+  int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
   TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
               ", actual_size_n = ", actual_size_n);

@@ -1274,11 +1273,9 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
     num_groups = b_scales.size(0);

   // Verify workspace size
-  TORCH_CHECK(
-      size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
-      ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
-  int min_workspace_size =
-      (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
+  TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
+              ", is not divisible by min_thread_n = ", marlin::min_thread_n);
+  int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
   TORCH_CHECK(workspace.numel() >= min_workspace_size,
               "workspace.numel = ", workspace.numel(),
               " is below min_workspace_size = ", min_workspace_size);
@@ -1290,14 +1287,14 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         b_scales.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), num_bits, num_groups, group_size, dev,
         at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
-        gptq_marlin::max_par);
+        marlin::max_par);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     fp8_marlin::marlin_mm_f16i4<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
         c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m,
         size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size,
         dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
-        gptq_marlin::max_par);
+        marlin::max_par);
   } else {
     TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16");
   }
csrc/quantization/gptq_marlin/awq_marlin_repack.cu (new file, 269 lines)
@@ -0,0 +1,269 @@
+#include "marlin.cuh"
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+
+namespace marlin {
+
+template <int const num_threads, int const num_bits, bool const has_perm>
+__global__ void awq_marlin_repack_kernel(
+    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
+    int size_k, int size_n) {}
+
+}  // namespace marlin
+
+torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
+                                int64_t size_k, int64_t size_n,
+                                int64_t num_bits) {
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
+  return torch::empty({1, 1});
+}
+
+#else
+
+namespace marlin {
+
+template <int const num_threads, int const num_bits>
+__global__ void awq_marlin_repack_kernel(
+    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
+    int size_k, int size_n) {
+  constexpr int pack_factor = 32 / num_bits;
+
+  int k_tiles = size_k / tile_k_size;
+  int n_tiles = size_n / tile_n_size;
+  int block_k_tiles = div_ceil(k_tiles, gridDim.x);
+
+  int start_k_tile = blockIdx.x * block_k_tiles;
+  if (start_k_tile >= k_tiles) {
+    return;
+  }
+
+  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
+
+  // Wait until the next thread tile has been loaded to shared memory.
+  auto wait_for_stage = [&]() {
+    // We only have `stages - 2` active fetches since we are double buffering
+    // and can only issue the next fetch when it is guaranteed that the previous
+    // shared memory load is fully complete (as it may otherwise be
+    // overwritten).
+    cp_async_wait<repack_stages - 2>();
+    __syncthreads();
+  };
+
+  extern __shared__ int4 sh[];
+
+  constexpr int tile_n_ints = tile_n_size / pack_factor;
+
+  constexpr int stage_n_threads = tile_n_ints / 4;
+  constexpr int stage_k_threads = tile_k_size;
+  constexpr int stage_size = stage_k_threads * stage_n_threads;
+
+  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
+    if (n_tile_id >= n_tiles) {
+      cp_async_fence();
+      return;
+    }
+
+    int first_n = n_tile_id * tile_n_size;
+    int first_n_packed = first_n / pack_factor;
+
+    int4* sh_ptr = sh + stage_size * pipe;
+
+    if (threadIdx.x < stage_size) {
+      int k_id = threadIdx.x / stage_n_threads;
+      int n_id = threadIdx.x % stage_n_threads;
+
+      int first_k = k_tile_id * tile_k_size;
+
+      cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
+                reinterpret_cast<int4 const*>(
+                    &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
+                                     first_n_packed + (n_id * 4)])));
+    }
+
+    cp_async_fence();
+  };
+
+  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
+    if (n_tile_id >= n_tiles) {
+      return;
+    }
+
+    int warp_id = threadIdx.x / 32;
+    int th_id = threadIdx.x % 32;
+
+    if (warp_id >= 4) {
+      return;
+    }
+
+    int tc_col = th_id / 4;
+    int tc_row = (th_id % 4) * 2;
+
+    constexpr int tc_offsets[4] = {0, 1, 8, 9};
+
+    int cur_n = warp_id * 16 + tc_col;
+    int cur_n_packed = cur_n / pack_factor;
+    int cur_n_pos = cur_n % pack_factor;
+
+    constexpr int sh_stride = tile_n_ints;
+    constexpr uint32_t mask = (1 << num_bits) - 1;
+
+    int4* sh_stage_ptr = sh + stage_size * pipe;
+    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
+
+    // Undo interleaving
+    int cur_n_pos_unpacked;
+    if constexpr (num_bits == 4) {
+      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
+      cur_n_pos_unpacked = undo_pack[cur_n_pos];
+    } else {
+      constexpr int undo_pack[4] = {0, 2, 1, 3};
+      cur_n_pos_unpacked = undo_pack[cur_n_pos];
+    }
+
+    uint32_t vals[8];
+#pragma unroll
+    for (int i = 0; i < 4; i++) {
+      int cur_elem = tc_row + tc_offsets[i];
+
+      int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
+      int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
+                                          sh_stride * cur_elem];
+
+      vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
+      vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
+    }
+
+    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
+    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
+
+    // Result of:
+    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+    if constexpr (num_bits == 4) {
+      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
+
+      uint32_t res = 0;
+#pragma unroll
+      for (int i = 0; i < 8; i++) {
+        res |= vals[pack_idx[i]] << (i * 4);
+      }
+
+      out_ptr[out_offset + th_id * 4 + warp_id] = res;
+
+    } else {
+      constexpr int pack_idx[4] = {0, 2, 1, 3};
+
+      uint32_t res1 = 0;
+      uint32_t res2 = 0;
+#pragma unroll
+      for (int i = 0; i < 4; i++) {
+        res1 |= vals[pack_idx[i]] << (i * 8);
+        res2 |= vals[4 + pack_idx[i]] << (i * 8);
+      }
+
+      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
+      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
+    }
+  };
+
+  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
+#pragma unroll
+    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
+      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
+    }
+
+    wait_for_stage();
+  };
+#pragma unroll
+  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
+    int n_tile_id = 0;
+
+    start_pipes(k_tile_id, n_tile_id);
+
+    while (n_tile_id < n_tiles) {
+#pragma unroll
+      for (int pipe = 0; pipe < repack_stages; pipe++) {
+        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
+                        n_tile_id + pipe + repack_stages - 1);
+        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
+        wait_for_stage();
+      }
+      n_tile_id += repack_stages;
+    }
+  }
+}
+
+}  // namespace marlin
+
+#define CALL_IF(NUM_BITS) \
+  else if (num_bits == NUM_BITS) { \
+    cudaFuncSetAttribute( \
+        marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
+        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
+    marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
+        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
+            b_q_weight_ptr, out_ptr, size_k, size_n); \
+  }
+
+torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
||||||
|
int64_t size_n, int64_t num_bits) {
|
||||||
|
// Verify compatibility with marlin tile of 16x64
|
||||||
|
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
|
||||||
|
" is not divisible by tile_k_size = ", marlin::tile_k_size);
|
||||||
|
TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
|
||||||
|
" is not divisible by tile_n_size = ", marlin::tile_n_size);
|
||||||
|
|
||||||
|
TORCH_CHECK(num_bits == 4 || num_bits == 8,
|
||||||
|
"num_bits must be 4 or 8. Got = ", num_bits);
|
||||||
|
int const pack_factor = 32 / num_bits;
|
||||||
|
|
||||||
|
// Verify B
|
||||||
|
TORCH_CHECK(b_q_weight.size(0) == size_k,
|
||||||
|
"b_q_weight.size(0) = ", b_q_weight.size(0),
|
||||||
|
" is not size_k = ", size_k);
|
||||||
|
TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
|
||||||
|
"Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
|
||||||
|
", size_n = ", size_n, ", pack_factor = ", pack_factor);
|
||||||
|
|
||||||
|
// Verify device and strides
|
||||||
|
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
|
||||||
|
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
|
||||||
|
TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
|
||||||
|
|
||||||
|
// Alloc buffers
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
|
||||||
|
auto options = torch::TensorOptions()
|
||||||
|
.dtype(b_q_weight.dtype())
|
||||||
|
.device(b_q_weight.device());
|
||||||
|
torch::Tensor out = torch::empty(
|
||||||
|
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
|
||||||
|
options);
|
||||||
|
|
||||||
|
// Get ptrs
|
||||||
|
uint32_t const* b_q_weight_ptr =
|
||||||
|
reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
|
||||||
|
uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
|
||||||
|
|
||||||
|
// Get dev info
|
||||||
|
int dev = b_q_weight.get_device();
|
||||||
|
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
|
||||||
|
int blocks;
|
||||||
|
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
||||||
|
|
||||||
|
int max_shared_mem = 0;
|
||||||
|
cudaDeviceGetAttribute(&max_shared_mem,
|
||||||
|
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||||
|
TORCH_CHECK(max_shared_mem > 0);
|
||||||
|
|
||||||
|
if (false) {
|
||||||
|
}
|
||||||
|
CALL_IF(4)
|
||||||
|
CALL_IF(8)
|
||||||
|
else {
|
||||||
|
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
|
||||||
|
}
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
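For reference, the 4-bit output packing used above can be reproduced on the host. This is a minimal sketch, assuming only the pack_idx table shown in the kernel; the pack_ft_order_4bit helper and its sample values are illustrative and not part of the diff.

#include <cstdint>
#include <cstdio>

// Packs eight already-extracted 4-bit values into one 32-bit word in the
// order expected by the interleaved FasterTransformer numeric conversion,
// mirroring the `res |= vals[pack_idx[i]] << (i * 4)` loop above.
static uint32_t pack_ft_order_4bit(const uint32_t vals[8]) {
  constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  uint32_t res = 0;
  for (int i = 0; i < 8; i++) {
    res |= (vals[pack_idx[i]] & 0xFu) << (i * 4);
  }
  return res;
}

int main() {
  uint32_t vals[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  printf("0x%08x\n", pack_ft_order_4bit(vals));  // prints 0x75316420
  return 0;
}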
@@ -19,8 +19,8 @@
  * Adapted from https://github.com/IST-DASLab/marlin
  */

-#include "gptq_marlin.cuh"
-#include "gptq_marlin_dtypes.cuh"
+#include "marlin.cuh"
+#include "marlin_dtypes.cuh"

 #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)      \
   static_assert(std::is_same<scalar_t, half>::value || \
@@ -32,7 +32,7 @@ inline std::string str(T x) {
   return std::to_string(x);
 }

-namespace gptq_marlin {
+namespace marlin {

 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

@@ -72,10 +72,11 @@ __global__ void Marlin(
 }  // namespace gptq_marlin

 torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
-                               torch::Tensor& b_scales, torch::Tensor& g_idx,
-                               torch::Tensor& perm, torch::Tensor& workspace,
-                               int64_t num_bits, int64_t size_m, int64_t size_n,
-                               int64_t size_k, bool is_k_full) {
+                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
+                               torch::Tensor& g_idx, torch::Tensor& perm,
+                               torch::Tensor& workspace, int64_t num_bits,
+                               int64_t size_m, int64_t size_n, int64_t size_k,
+                               bool is_k_full) {
   TORCH_CHECK_NOT_IMPLEMENTED(false,
                               "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
   return torch::empty({1, 1});
@@ -264,6 +265,114 @@ dequant_8bit<nv_bfloat16>(int q) {
   return frag_b;
 }

+// Zero-point dequantizers
+
+template <typename scalar_t>
+__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit_zp(int q) {
+  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
+}
+
+template <>
+__device__ inline typename ScalarType<half>::FragB dequant_4bit_zp<half>(
+    int q) {
+  const int LO = 0x000f000f;
+  const int HI = 0x00f000f0;
+  const int EX = 0x64006400;
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
+
+  const int SUB = 0x64006400;
+  const int MUL = 0x2c002c00;
+  const int ADD = 0xd400d400;
+  typename ScalarType<half>::FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&SUB));
+  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&MUL),
+                      *reinterpret_cast<const half2*>(&ADD));
+  return frag_b;
+}
+
+template <>
+__device__ inline typename ScalarType<nv_bfloat16>::FragB
+dequant_4bit_zp<nv_bfloat16>(int q) {
+  static constexpr uint32_t MASK = 0x000f000f;
+  static constexpr uint32_t EX = 0x43004300;
+
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+  q >>= 4;
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+
+  typename ScalarType<nv_bfloat16>::FragB frag_b;
+  static constexpr uint32_t MUL = 0x3F803F80;
+  static constexpr uint32_t ADD = 0xC300C300;
+
+  frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
+                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
+                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
+                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
+                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  return frag_b;
+}
+
+template <typename scalar_t>
+__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit_zp(int q) {
+  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
+}
+
+template <>
+__device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
+    int q) {
+  static constexpr uint32_t mask_for_elt_01 = 0x5250;
+  static constexpr uint32_t mask_for_elt_23 = 0x5351;
+  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
+
+  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
+  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
+
+  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
+
+  typename ScalarType<half>::FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  return frag_b;
+}
+
+template <>
+__device__ inline typename ScalarType<nv_bfloat16>::FragB
+dequant_8bit_zp<nv_bfloat16>(int q) {
+  typename ScalarType<nv_bfloat16>::FragB frag_b;
+
+  float fp32_intermediates[4];
+  uint32_t* fp32_intermediates_casted =
+      reinterpret_cast<uint32_t*>(fp32_intermediates);
+
+  static constexpr uint32_t fp32_base = 0x4B000000;
+  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
+  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
+  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
+  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
+
+  fp32_intermediates[0] -= 8388608.f;
+  fp32_intermediates[1] -= 8388608.f;
+  fp32_intermediates[2] -= 8388608.f;
+  fp32_intermediates[3] -= 8388608.f;
+
+  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
+  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
+                                   fp32_intermediates_casted[1], 0x7632);
+  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
+                                   fp32_intermediates_casted[3], 0x7632);
+
+  return frag_b;
+}
+
 // Multiply dequantized values by the corresponding quantization scale; used
 // only for grouped quantization.
 template <typename scalar_t>
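The half specializations above build fp16 bit patterns directly instead of issuing integer-to-float conversions. The constants can be sanity-checked with plain float arithmetic; the small host program below is illustrative only and not part of the diff.

#include <cassert>
#include <cstdio>

// LO path: OR-ing a nibble n into 0x6400 encodes the fp16 value 1024 + n, and
// subtracting SUB = 0x6400 (1024.0) recovers n.
// HI path: the nibble sits in bits 4..7, so the same trick encodes 1024 + 16*n;
// fusing MUL = 0x2c00 (0.0625) and ADD = 0xd400 (-64.0) recovers n.
int main() {
  for (int n = 0; n < 16; n++) {
    float lo = (1024.0f + n) - 1024.0f;
    float hi = (1024.0f + 16.0f * n) * 0.0625f + (-64.0f);
    assert(lo == static_cast<float>(n));
    assert(hi == static_cast<float>(n));
  }
  printf("dequant_4bit_zp constants check out for all 16 nibble values\n");
  return 0;
}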
@@ -277,6 +386,17 @@ __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
   frag_b[1] = __hmul2(frag_b[1], s);
 }

+template <typename scalar_t>
+__device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b,
+                              typename ScalarType<scalar_t>::scalar_t2& frag_zp,
+                              int i) {
+  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
+  scalar_t2 zp =
+      ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);
+  frag_b[0] = __hsub2(frag_b[0], zp);
+  frag_b[1] = __hsub2(frag_b[1], zp);
+}
+
 // Same as above, but for act_order (each K is multiplied individually)
 template <typename scalar_t>
 __device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,
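Together with the dequantizers above, sub_zp implements the standard asymmetric reconstruction: the stored code and the group zero-point are both expanded to fp16, the zero-point is subtracted, and the group scale is applied afterwards by scale/scale4. A scalar sketch of the per-weight arithmetic, for illustration only:

// w = (q - z) * s, with q the unsigned 4/8-bit code, z the group zero-point
// and s the group scale. The symmetric GPTQ path instead folds a fixed offset
// into dequant_4bit/dequant_8bit and skips sub_zp.
static inline float dequant_asymmetric(int q, int z, float s) {
  return static_cast<float>(q - z) * s;
}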
@@ -404,6 +524,7 @@ template <typename scalar_t, // compute dtype, half or nv_float16
|
|||||||
const int stages, // number of stages for the async global->shared
|
const int stages, // number of stages for the async global->shared
|
||||||
// fetch pipeline
|
// fetch pipeline
|
||||||
const bool has_act_order, // whether act_order is enabled
|
const bool has_act_order, // whether act_order is enabled
|
||||||
|
const bool has_zp, // whether zero-points are enabled
|
||||||
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
||||||
// with a separate quantization scale
|
// with a separate quantization scale
|
||||||
>
|
>
|
||||||
@@ -413,6 +534,8 @@ __global__ void Marlin(
|
|||||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||||
// (k/groupsize)xn
|
// (k/groupsize)xn
|
||||||
|
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
||||||
|
// (k/groupsize)x(n/pack_factor)
|
||||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||||
int num_groups, // number of scale groups per output channel
|
int num_groups, // number of scale groups per output channel
|
||||||
int prob_m, // batch dimension m
|
int prob_m, // batch dimension m
|
||||||
@@ -437,6 +560,7 @@ __global__ void Marlin(
|
|||||||
using FragB = typename ScalarType<scalar_t>::FragB;
|
using FragB = typename ScalarType<scalar_t>::FragB;
|
||||||
using FragC = typename ScalarType<scalar_t>::FragC;
|
using FragC = typename ScalarType<scalar_t>::FragC;
|
||||||
using FragS = typename ScalarType<scalar_t>::FragS;
|
using FragS = typename ScalarType<scalar_t>::FragS;
|
||||||
|
using FragZP = typename ScalarType<scalar_t>::FragZP;
|
||||||
|
|
||||||
constexpr int pack_factor = 32 / num_bits;
|
constexpr int pack_factor = 32 / num_bits;
|
||||||
|
|
||||||
@@ -566,6 +690,13 @@ __global__ void Marlin(
|
|||||||
int tb_n_warps = thread_n_blocks / 4;
|
int tb_n_warps = thread_n_blocks / 4;
|
||||||
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
|
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
|
||||||
|
|
||||||
|
// Zero-points sizes/strides
|
||||||
|
int zp_gl_stride = (prob_n / pack_factor) / 4;
|
||||||
|
constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
|
||||||
|
constexpr int zp_tb_groups = s_tb_groups;
|
||||||
|
constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
|
||||||
|
int zp_gl_rd_delta = zp_gl_stride;
|
||||||
|
|
||||||
// Global A read index of current thread.
|
// Global A read index of current thread.
|
||||||
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
||||||
(threadIdx.x % a_gl_rd_delta_o);
|
(threadIdx.x % a_gl_rd_delta_o);
|
||||||
@@ -605,6 +736,19 @@ __global__ void Marlin(
|
|||||||
int s_sh_wr = threadIdx.x;
|
int s_sh_wr = threadIdx.x;
|
||||||
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
|
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
|
||||||
|
|
||||||
|
// Zero-points
|
||||||
|
int zp_gl_rd;
|
||||||
|
if constexpr (has_zp) {
|
||||||
|
if constexpr (group_blocks == -1) {
|
||||||
|
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
|
||||||
|
} else {
|
||||||
|
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
||||||
|
zp_sh_stride * slice_col + threadIdx.x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int zp_sh_wr = threadIdx.x;
|
||||||
|
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
|
||||||
|
|
||||||
// We use a different scale layout for grouped and column-wise quantization as
|
// We use a different scale layout for grouped and column-wise quantization as
|
||||||
// we scale a `half2` tile in column-major layout in the former and in
|
// we scale a `half2` tile in column-major layout in the former and in
|
||||||
// row-major in the latter case.
|
// row-major in the latter case.
|
||||||
@@ -616,6 +760,18 @@ __global__ void Marlin(
|
|||||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||||
(threadIdx.x % 32) % 4;
|
(threadIdx.x % 32) % 4;
|
||||||
|
|
||||||
|
// Zero-points have the same read layout as the scales
|
||||||
|
// (without column-wise case)
|
||||||
|
constexpr int num_col_threads = 8;
|
||||||
|
constexpr int num_row_threads = 4;
|
||||||
|
constexpr int num_ints_per_thread = 8 / pack_factor;
|
||||||
|
int zp_sh_rd;
|
||||||
|
if constexpr (has_zp) {
|
||||||
|
zp_sh_rd = num_ints_per_thread * num_col_threads *
|
||||||
|
((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||||
|
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
|
||||||
|
}
|
||||||
|
|
||||||
// Precompute which thread should not read memory in which iterations; this is
|
// Precompute which thread should not read memory in which iterations; this is
|
||||||
// needed if there are more threads than required for a certain tilesize or
|
// needed if there are more threads than required for a certain tilesize or
|
||||||
// when the batchsize is not a multiple of 16.
|
// when the batchsize is not a multiple of 16.
|
||||||
@@ -664,14 +820,17 @@ __global__ void Marlin(
|
|||||||
int4* sh_a = sh;
|
int4* sh_a = sh;
|
||||||
int4* sh_b = sh_a + (stages * a_sh_stage);
|
int4* sh_b = sh_a + (stages * a_sh_stage);
|
||||||
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
|
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
|
||||||
int4* sh_s = sh_g_idx + (stages * g_idx_stage);
|
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
|
||||||
|
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
||||||
|
|
||||||
// Register storage for double buffer of shared memory reads.
|
// Register storage for double buffer of shared memory reads.
|
||||||
FragA frag_a[2][thread_m_blocks];
|
FragA frag_a[2][thread_m_blocks];
|
||||||
I4 frag_b_quant[2][b_thread_vecs];
|
I4 frag_b_quant[2][b_thread_vecs];
|
||||||
FragC frag_c[thread_m_blocks][4][2];
|
FragC frag_c[thread_m_blocks][4][2];
|
||||||
FragS frag_s[2][4]; // No act-order
|
FragS frag_s[2][4]; // No act-order
|
||||||
FragS act_frag_s[2][4][4]; // For act-order
|
FragS act_frag_s[2][4][4]; // For act-order
|
||||||
|
int frag_qzp[2][num_ints_per_thread]; // Zero-points
|
||||||
|
FragZP frag_zp; // Zero-points in fp16
|
||||||
|
|
||||||
// Zero accumulators.
|
// Zero accumulators.
|
||||||
auto zero_accums = [&]() {
|
auto zero_accums = [&]() {
|
||||||
@@ -777,6 +936,28 @@ __global__ void Marlin(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if constexpr (has_zp && group_blocks != -1) {
|
||||||
|
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
||||||
|
|
||||||
|
if constexpr (group_blocks >= thread_k_blocks) {
|
||||||
|
// Only fetch zero-points if this tile starts a new group
|
||||||
|
if (pipe % (group_blocks / thread_k_blocks) == 0) {
|
||||||
|
if (zp_sh_wr_pred) {
|
||||||
|
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
|
||||||
|
}
|
||||||
|
zp_gl_rd += zp_gl_rd_delta;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < zp_tb_groups; i++) {
|
||||||
|
if (zp_sh_wr_pred) {
|
||||||
|
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
|
||||||
|
&zp_ptr[zp_gl_rd]);
|
||||||
|
}
|
||||||
|
zp_gl_rd += zp_gl_rd_delta;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Insert a fence even when we are winding down the pipeline to ensure that
|
// Insert a fence even when we are winding down the pipeline to ensure that
|
||||||
@@ -784,6 +965,12 @@ __global__ void Marlin(
|
|||||||
cp_async_fence();
|
cp_async_fence();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
auto fetch_zp_to_shared = [&]() {
|
||||||
|
if (zp_sh_wr_pred) {
|
||||||
|
cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Wait until the next thread tile has been loaded to shared memory.
|
// Wait until the next thread tile has been loaded to shared memory.
|
||||||
auto wait_for_stage = [&]() {
|
auto wait_for_stage = [&]() {
|
||||||
// We only have `stages - 2` active fetches since we are double buffering
|
// We only have `stages - 2` active fetches since we are double buffering
|
||||||
@@ -932,8 +1119,73 @@ __global__ void Marlin(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
auto fetch_zp_to_registers = [&](int k, int full_pipe) {
|
||||||
|
if constexpr (!has_zp) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int pipe = full_pipe % stages;
|
||||||
|
|
||||||
|
if constexpr (group_blocks == -1) {
|
||||||
|
for (int i = 0; i < num_ints_per_thread; i++) {
|
||||||
|
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if constexpr (group_blocks >= thread_k_blocks) {
|
||||||
|
int4* sh_zp_stage =
|
||||||
|
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
|
||||||
|
(pipe / (group_blocks / thread_k_blocks)));
|
||||||
|
for (int i = 0; i < num_ints_per_thread; i++) {
|
||||||
|
frag_qzp[k % 2][i] =
|
||||||
|
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
int warp_id = threadIdx.x / 32;
|
||||||
|
int n_warps = thread_n_blocks / 4;
|
||||||
|
|
||||||
|
int warp_row = warp_id / n_warps;
|
||||||
|
|
||||||
|
int cur_k = warp_row * 16;
|
||||||
|
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
||||||
|
|
||||||
|
int k_blocks = cur_k / 16;
|
||||||
|
int cur_group_id = k_blocks / group_blocks;
|
||||||
|
|
||||||
|
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
||||||
|
|
||||||
|
sh_zp_stage += cur_group_id * zp_sh_stride;
|
||||||
|
|
||||||
|
for (int i = 0; i < num_ints_per_thread; i++) {
|
||||||
|
frag_qzp[k % 2][i] =
|
||||||
|
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Execute the actual tensor core matmul of a sub-tile.
|
// Execute the actual tensor core matmul of a sub-tile.
|
||||||
auto matmul = [&](int k) {
|
auto matmul = [&](int k) {
|
||||||
|
if constexpr (has_zp) {
|
||||||
|
FragB frag_zp_0;
|
||||||
|
FragB frag_zp_1;
|
||||||
|
if constexpr (num_bits == 4) {
|
||||||
|
int zp_quant = frag_qzp[k % 2][0];
|
||||||
|
int zp_quant_shift = zp_quant >> 8;
|
||||||
|
frag_zp_0 = dequant_4bit_zp<scalar_t>(zp_quant);
|
||||||
|
frag_zp_1 = dequant_4bit_zp<scalar_t>(zp_quant_shift);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
int zp_quant_0 = frag_qzp[k % 2][0];
|
||||||
|
int zp_quant_1 = frag_qzp[k % 2][1];
|
||||||
|
frag_zp_0 = dequant_8bit_zp<scalar_t>(zp_quant_0);
|
||||||
|
frag_zp_1 = dequant_8bit_zp<scalar_t>(zp_quant_1);
|
||||||
|
}
|
||||||
|
|
||||||
|
frag_zp[0] = frag_zp_0[0];
|
||||||
|
frag_zp[1] = frag_zp_0[1];
|
||||||
|
frag_zp[2] = frag_zp_1[0];
|
||||||
|
frag_zp[3] = frag_zp_1[1];
|
||||||
|
}
|
||||||
|
|
||||||
// We have the m dimension as the inner loop in order to encourage overlapping
|
// We have the m dimension as the inner loop in order to encourage overlapping
|
||||||
// dequantization and matmul operations.
|
// dequantization and matmul operations.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@@ -944,16 +1196,32 @@ __global__ void Marlin(
|
|||||||
int b_quant = frag_b_quant[k % 2][0][j];
|
int b_quant = frag_b_quant[k % 2][0][j];
|
||||||
int b_quant_shift = b_quant >> 8;
|
int b_quant_shift = b_quant >> 8;
|
||||||
|
|
||||||
frag_b0 = dequant_4bit<scalar_t>(b_quant);
|
if constexpr (has_zp) {
|
||||||
frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
|
frag_b0 = dequant_4bit_zp<scalar_t>(b_quant);
|
||||||
|
frag_b1 = dequant_4bit_zp<scalar_t>(b_quant_shift);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
frag_b0 = dequant_4bit<scalar_t>(b_quant);
|
||||||
|
frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
|
||||||
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
|
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
|
||||||
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
|
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
|
||||||
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
|
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
|
||||||
|
|
||||||
frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
|
if constexpr (has_zp) {
|
||||||
frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
|
frag_b0 = dequant_8bit_zp<scalar_t>(b_quant_0);
|
||||||
|
frag_b1 = dequant_8bit_zp<scalar_t>(b_quant_1);
|
||||||
|
} else {
|
||||||
|
frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
|
||||||
|
frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply zero-point to frag_b0
|
||||||
|
if constexpr (has_zp) {
|
||||||
|
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply scale to frag_b0
|
// Apply scale to frag_b0
|
||||||
@@ -967,6 +1235,11 @@ __global__ void Marlin(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply zero-point to frag_b1
|
||||||
|
if constexpr (has_zp) {
|
||||||
|
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
|
||||||
|
}
|
||||||
|
|
||||||
// Apply scale to frag_b1
|
// Apply scale to frag_b1
|
||||||
if constexpr (has_act_order) {
|
if constexpr (has_act_order) {
|
||||||
scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],
|
scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],
|
||||||
@@ -1189,6 +1462,12 @@ __global__ void Marlin(
|
|||||||
}
|
}
|
||||||
fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
|
fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if constexpr (has_zp && group_blocks == -1) {
|
||||||
|
if (i == 0) {
|
||||||
|
fetch_zp_to_shared();
|
||||||
|
}
|
||||||
|
}
|
||||||
fetch_to_shared(i, i, i < slice_iters);
|
fetch_to_shared(i, i, i < slice_iters);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1197,6 +1476,7 @@ __global__ void Marlin(
|
|||||||
init_same_group(0);
|
init_same_group(0);
|
||||||
fetch_to_registers(0, 0);
|
fetch_to_registers(0, 0);
|
||||||
fetch_scales_to_registers(0, 0);
|
fetch_scales_to_registers(0, 0);
|
||||||
|
fetch_zp_to_registers(0, 0);
|
||||||
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
|
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
|
||||||
slice_k_start_shared_fetch += tb_k * (stages - 1);
|
slice_k_start_shared_fetch += tb_k * (stages - 1);
|
||||||
};
|
};
|
||||||
@@ -1217,6 +1497,7 @@ __global__ void Marlin(
|
|||||||
for (int k = 0; k < b_sh_wr_iters; k++) {
|
for (int k = 0; k < b_sh_wr_iters; k++) {
|
||||||
fetch_to_registers(k + 1, pipe % stages);
|
fetch_to_registers(k + 1, pipe % stages);
|
||||||
fetch_scales_to_registers(k + 1, pipe);
|
fetch_scales_to_registers(k + 1, pipe);
|
||||||
|
fetch_zp_to_registers(k + 1, pipe);
|
||||||
if (k == b_sh_wr_iters - 2) {
|
if (k == b_sh_wr_iters - 2) {
|
||||||
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
||||||
slice_iters >= stages);
|
slice_iters >= stages);
|
||||||
@@ -1354,6 +1635,7 @@ __global__ void Marlin(
|
|||||||
|
|
||||||
} else {
|
} else {
|
||||||
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
||||||
|
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
|
||||||
}
|
}
|
||||||
|
|
||||||
start_pipes();
|
start_pipes();
|
||||||
@@ -1363,22 +1645,24 @@ __global__ void Marlin(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
||||||
THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
|
THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
|
||||||
|
NUM_THREADS) \
|
||||||
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
|
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||||
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
|
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
|
||||||
num_threads == NUM_THREADS) { \
|
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
|
||||||
cudaFuncSetAttribute( \
|
cudaFuncSetAttribute( \
|
||||||
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
|
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
|
||||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
|
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
|
||||||
GROUP_BLOCKS>, \
|
HAS_ZP, GROUP_BLOCKS>, \
|
||||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||||
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
|
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
|
||||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
|
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
|
||||||
GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
HAS_ZP, GROUP_BLOCKS> \
|
||||||
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
|
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
||||||
prob_k, locks); \
|
A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, \
|
||||||
|
prob_m, prob_n, prob_k, locks); \
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
@@ -1548,39 +1832,61 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
|
|||||||
return exec_config_t{0, {-1, -1, -1}};
|
return exec_config_t{0, {-1, -1, -1}};
|
||||||
}
|
}
|
||||||
|
|
||||||
#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
#define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||||
\
|
\
|
||||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
|
||||||
\
|
\
|
||||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
|
||||||
\
|
\
|
||||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
|
||||||
\
|
\
|
||||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
|
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
|
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
|
||||||
|
|
||||||
|
#define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
|
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||||
|
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||||
|
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||||
|
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
|
||||||
|
\
|
||||||
|
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||||
|
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||||
|
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||||
|
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
|
||||||
|
\
|
||||||
|
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||||
|
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||||
|
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||||
|
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
|
||||||
|
\
|
||||||
|
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||||
|
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||||
|
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||||
|
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
|
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
|
||||||
void* g_idx, void* perm, void* a_tmp, int prob_m,
|
void* g_idx, void* perm, void* a_tmp, int prob_m,
|
||||||
int prob_n, int prob_k, void* workspace, int num_bits,
|
int prob_n, int prob_k, void* workspace, int num_bits,
|
||||||
bool has_act_order, bool is_k_full, int num_groups,
|
bool has_act_order, bool is_k_full, bool has_zp,
|
||||||
int group_size, int dev, cudaStream_t stream, int thread_k,
|
int num_groups, int group_size, int dev,
|
||||||
int thread_n, int sms, int max_par) {
|
cudaStream_t stream, int thread_k, int thread_n, int sms,
|
||||||
|
int max_par) {
|
||||||
TORCH_CHECK(num_bits == 4 || num_bits == 8,
|
TORCH_CHECK(num_bits == 4 || num_bits == 8,
|
||||||
"num_bits must be 4 or 8. Got = ", num_bits);
|
"num_bits must be 4 or 8. Got = ", num_bits);
|
||||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||||
@@ -1665,6 +1971,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
|
|||||||
const int4* B_ptr = (const int4*)B;
|
const int4* B_ptr = (const int4*)B;
|
||||||
int4* C_ptr = (int4*)C;
|
int4* C_ptr = (int4*)C;
|
||||||
const int4* s_ptr = (const int4*)s;
|
const int4* s_ptr = (const int4*)s;
|
||||||
|
const int4* zp_ptr = (const int4*)zp;
|
||||||
const int* g_idx_ptr = (const int*)g_idx;
|
const int* g_idx_ptr = (const int*)g_idx;
|
||||||
const int* perm_ptr = (const int*)perm;
|
const int* perm_ptr = (const int*)perm;
|
||||||
int4* a_tmp_ptr = (int4*)a_tmp;
|
int4* a_tmp_ptr = (int4*)a_tmp;
|
||||||
@@ -1701,28 +2008,33 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
|
|||||||
thread_m_blocks = exec_cfg.max_m_blocks;
|
thread_m_blocks = exec_cfg.max_m_blocks;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Define kernel configurations
|
|
||||||
if (false) {
|
if (false) {
|
||||||
}
|
}
|
||||||
CALL_IF(4, 32, 2, 256)
|
GPTQ_CALL_IF(4, 16, 4, 256)
|
||||||
CALL_IF(4, 16, 4, 256)
|
GPTQ_CALL_IF(4, 8, 8, 256)
|
||||||
CALL_IF(4, 8, 8, 256)
|
GPTQ_CALL_IF(4, 8, 4, 128)
|
||||||
CALL_IF(4, 8, 4, 128)
|
GPTQ_CALL_IF(4, 4, 8, 128)
|
||||||
CALL_IF(4, 4, 8, 128)
|
GPTQ_CALL_IF(8, 16, 4, 256)
|
||||||
CALL_IF(8, 32, 2, 256)
|
GPTQ_CALL_IF(8, 8, 8, 256)
|
||||||
CALL_IF(8, 16, 4, 256)
|
GPTQ_CALL_IF(8, 8, 4, 128)
|
||||||
CALL_IF(8, 8, 8, 256)
|
GPTQ_CALL_IF(8, 4, 8, 128)
|
||||||
CALL_IF(8, 8, 4, 128)
|
|
||||||
CALL_IF(8, 4, 8, 128)
|
AWQ_CALL_IF(4, 16, 4, 256)
|
||||||
|
AWQ_CALL_IF(4, 8, 8, 256)
|
||||||
|
AWQ_CALL_IF(4, 8, 4, 128)
|
||||||
|
AWQ_CALL_IF(4, 4, 8, 128)
|
||||||
|
AWQ_CALL_IF(8, 16, 4, 256)
|
||||||
|
AWQ_CALL_IF(8, 8, 8, 256)
|
||||||
|
AWQ_CALL_IF(8, 8, 4, 128)
|
||||||
|
AWQ_CALL_IF(8, 4, 8, 128)
|
||||||
else {
|
else {
|
||||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
|
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||||
str(prob_n) + ", " + str(prob_k) + "]" +
|
", ", prob_k, "]", ", has_act_order = ", has_act_order,
|
||||||
", has_act_order = " + str(has_act_order) +
|
", num_groups = ", num_groups, ", group_size = ", group_size,
|
||||||
", num_groups = " + str(num_groups) +
|
", thread_m_blocks = ", thread_m_blocks,
|
||||||
", group_size = " + str(group_size) +
|
", thread_n_blocks = ", thread_n_blocks,
|
||||||
", thread_m_blocks = " + str(thread_m_blocks) +
|
", thread_k_blocks = ", thread_k_blocks,
|
||||||
", thread_n_blocks = " + str(thread_n_blocks) +
|
", num_bits = ", num_bits);
|
||||||
", thread_k_blocks = " + str(thread_k_blocks));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
|
A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
|
||||||
@@ -1733,10 +2045,11 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
|
|||||||
} // namespace gptq_marlin
|
} // namespace gptq_marlin
|
||||||
|
|
||||||
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||||
torch::Tensor& b_scales, torch::Tensor& g_idx,
|
torch::Tensor& b_scales, torch::Tensor& b_zeros,
|
||||||
torch::Tensor& perm, torch::Tensor& workspace,
|
torch::Tensor& g_idx, torch::Tensor& perm,
|
||||||
int64_t num_bits, int64_t size_m, int64_t size_n,
|
torch::Tensor& workspace, int64_t num_bits,
|
||||||
int64_t size_k, bool is_k_full) {
|
int64_t size_m, int64_t size_n, int64_t size_k,
|
||||||
|
bool is_k_full, bool has_zp) {
|
||||||
// Verify num_bits
|
// Verify num_bits
|
||||||
TORCH_CHECK(num_bits == 4 || num_bits == 8,
|
TORCH_CHECK(num_bits == 4 || num_bits == 8,
|
||||||
"num_bits must be 4 or 8. Got = ", num_bits);
|
"num_bits must be 4 or 8. Got = ", num_bits);
|
||||||
@@ -1749,16 +2062,15 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
|||||||
", size_k = ", size_k);
|
", size_k = ", size_k);
|
||||||
|
|
||||||
// Verify B
|
// Verify B
|
||||||
TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
|
TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
|
||||||
" is not divisible by tile_size = ", gptq_marlin::tile_size);
|
" is not divisible by tile_size = ", marlin::tile_size);
|
||||||
TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
|
TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
|
||||||
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
|
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
|
||||||
", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
|
", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
|
||||||
TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
|
TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
|
||||||
"b_q_weight.size(1) = ", b_q_weight.size(1),
|
"b_q_weight.size(1) = ", b_q_weight.size(1),
|
||||||
" is not divisible by tile_size = ", gptq_marlin::tile_size);
|
" is not divisible by tile_size = ", marlin::tile_size);
|
||||||
int actual_size_n =
|
int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
|
||||||
(b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
|
|
||||||
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
|
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
|
||||||
", actual_size_n = ", actual_size_n);
|
", actual_size_n = ", actual_size_n);
|
||||||
|
|
||||||
@@ -1772,6 +2084,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
|||||||
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
|
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
|
||||||
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
|
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
|
||||||
|
|
||||||
|
TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
|
||||||
|
TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
|
||||||
|
|
||||||
TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
|
TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
|
||||||
TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
|
TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
|
||||||
|
|
||||||
@@ -1805,8 +2120,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
|||||||
int group_size = -1;
|
int group_size = -1;
|
||||||
bool has_act_order = g_idx.size(0) != 0;
|
bool has_act_order = g_idx.size(0) != 0;
|
||||||
|
|
||||||
int b_rank = b_scales.sizes().size();
|
int rank = b_scales.sizes().size();
|
||||||
TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2");
|
TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
|
||||||
TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
|
TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
|
||||||
" is not size_n = ", size_n);
|
" is not size_n = ", size_n);
|
||||||
num_groups = b_scales.size(0);
|
num_groups = b_scales.size(0);
|
||||||
@@ -1832,34 +2147,44 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verify b_zeros
|
||||||
|
if (has_zp) {
|
||||||
|
int rank = b_zeros.sizes().size();
|
||||||
|
TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
|
||||||
|
TORCH_CHECK(b_zeros.size(0) == num_groups,
|
||||||
|
"b_zeros dim 0 = ", b_zeros.size(0),
|
||||||
|
" is not num_groups = ", num_groups);
|
||||||
|
TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
|
||||||
|
"b_zeros dim 1 = ", b_scales.size(1),
|
||||||
|
" is not size_n / pack_factor = ", size_n / pack_factor);
|
||||||
|
}
|
||||||
|
|
||||||
// Verify workspace size
|
// Verify workspace size
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
|
||||||
size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
|
", is not divisible by min_thread_n = ", marlin::min_thread_n);
|
||||||
", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
|
int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
|
||||||
int min_workspace_size =
|
|
||||||
(size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
|
|
||||||
TORCH_CHECK(workspace.numel() >= min_workspace_size,
|
TORCH_CHECK(workspace.numel() >= min_workspace_size,
|
||||||
"workspace.numel = ", workspace.numel(),
|
"workspace.numel = ", workspace.numel(),
|
||||||
" is below min_workspace_size = ", min_workspace_size);
|
" is below min_workspace_size = ", min_workspace_size);
|
||||||
|
|
||||||
int dev = a.get_device();
|
int dev = a.get_device();
|
||||||
if (a.scalar_type() == at::ScalarType::Half) {
|
if (a.scalar_type() == at::ScalarType::Half) {
|
||||||
gptq_marlin::marlin_mm_f16i4<half>(
|
marlin::marlin_mm_f16i4<half>(
|
||||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||||
b_scales.data_ptr<at::Half>(), g_idx.data_ptr(), perm.data_ptr(),
|
b_scales.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||||
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
|
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
|
||||||
workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups,
|
workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
|
||||||
group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
|
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||||
thread_n, sms, gptq_marlin::max_par);
|
thread_k, thread_n, sms, marlin::max_par);
|
||||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||||
gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(
|
marlin::marlin_mm_f16i4<nv_bfloat16>(
|
||||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||||
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
|
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
|
||||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||||
size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order,
|
a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
|
||||||
is_k_full, num_groups, group_size, dev,
|
workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
|
||||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||||
gptq_marlin::max_par);
|
thread_k, thread_n, sms, marlin::max_par);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
|
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,23 +1,16 @@
|
|||||||
#include "gptq_marlin.cuh"
|
#include "marlin.cuh"
|
||||||
|
|
||||||
namespace gptq_marlin {
|
|
||||||
|
|
||||||
static constexpr int repack_stages = 8;
|
|
||||||
|
|
||||||
static constexpr int repack_threads = 256;
|
|
||||||
|
|
||||||
static constexpr int tile_k_size = tile_size;
|
|
||||||
static constexpr int tile_n_size = tile_k_size * 4;
|
|
||||||
|
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||||
|
|
||||||
|
namespace marlin {
|
||||||
|
|
||||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
template <int const num_threads, int const num_bits, bool const has_perm>
|
||||||
__global__ void marlin_repack_kernel(
|
__global__ void gptq_marlin_repack_kernel(
|
||||||
uint32_t const* __restrict__ b_q_weight_ptr,
|
uint32_t const* __restrict__ b_q_weight_ptr,
|
||||||
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
|
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
|
||||||
int size_k, int size_n) {}
|
int size_k, int size_n) {}
|
||||||
|
|
||||||
} // namespace gptq_marlin
|
} // namespace marlin
|
||||||
|
|
||||||
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||||
int64_t size_k, int64_t size_n,
|
int64_t size_k, int64_t size_n,
|
||||||
@@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
|||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
|
namespace marlin {
|
||||||
|
|
||||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
template <int const num_threads, int const num_bits, bool const has_perm>
|
||||||
__global__ void marlin_repack_kernel(
|
__global__ void gptq_marlin_repack_kernel(
|
||||||
uint32_t const* __restrict__ b_q_weight_ptr,
|
uint32_t const* __restrict__ b_q_weight_ptr,
|
||||||
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
|
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
|
||||||
int size_k, int size_n) {
|
int size_k, int size_n) {
|
||||||
@@ -259,28 +254,28 @@ __global__ void marlin_repack_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gptq_marlin
|
} // namespace marlin
|
||||||
|
|
||||||
#define CALL_IF(NUM_BITS, HAS_PERM) \
|
#define CALL_IF(NUM_BITS, HAS_PERM) \
|
||||||
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
|
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
|
||||||
cudaFuncSetAttribute( \
|
cudaFuncSetAttribute( \
|
||||||
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
|
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
|
||||||
NUM_BITS, HAS_PERM>, \
|
HAS_PERM>, \
|
||||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||||
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
|
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
|
||||||
HAS_PERM> \
|
HAS_PERM> \
|
||||||
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
|
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
|
||||||
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
|
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||||
int64_t size_k, int64_t size_n,
|
int64_t size_k, int64_t size_n,
|
||||||
int64_t num_bits) {
|
int64_t num_bits) {
|
||||||
// Verify compatibility with marlin tile of 16x64
|
// Verify compatibility with marlin tile of 16x64
|
||||||
TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
|
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
|
||||||
" is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
|
" is not divisible by tile_k_size = ", marlin::tile_k_size);
|
||||||
TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
|
TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
|
||||||
" is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
|
" is not divisible by tile_n_size = ", marlin::tile_n_size);
|
||||||
|
|
||||||
TORCH_CHECK(num_bits == 4 || num_bits == 8,
|
TORCH_CHECK(num_bits == 4 || num_bits == 8,
|
||||||
"num_bits must be 4 or 8. Got = ", num_bits);
|
"num_bits must be 4 or 8. Got = ", num_bits);
|
||||||
@@ -308,10 +303,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
|||||||
auto options = torch::TensorOptions()
|
auto options = torch::TensorOptions()
|
||||||
.dtype(b_q_weight.dtype())
|
.dtype(b_q_weight.dtype())
|
||||||
.device(b_q_weight.device());
|
.device(b_q_weight.device());
|
||||||
torch::Tensor out =
|
torch::Tensor out = torch::empty(
|
||||||
torch::empty({size_k / gptq_marlin::tile_size,
|
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
|
||||||
size_n * gptq_marlin::tile_size / pack_factor},
|
options);
|
||||||
options);
|
|
||||||
|
|
||||||
// Detect if there is act_order
|
// Detect if there is act_order
|
||||||
bool has_perm = perm.size(0) != 0;
|
bool has_perm = perm.size(0) != 0;
|
||||||
|
|||||||
@@ -9,7 +9,9 @@
 #include <cuda_runtime.h>
 #include <iostream>

-namespace gptq_marlin {
+namespace marlin {
+
+// Marlin params

 // 8 warps are a good choice since every SM has 4 schedulers and having more
 // than 1 warp per schedule allows some more latency hiding. At the same time,
@@ -25,6 +27,15 @@ static constexpr int min_thread_k = 64;
 static constexpr int tile_size = 16;
 static constexpr int max_par = 16;

+// Repack params
+static constexpr int repack_stages = 8;
+
+static constexpr int repack_threads = 256;
+
+static constexpr int tile_k_size = tile_size;
+static constexpr int tile_n_size = tile_k_size * 4;
+
+// Helpers
 template <typename T, int n>
 struct Vec {
   T elems[n];
@@ -73,4 +84,4 @@ __device__ inline void cp_async_wait() {

 #endif

-}  // namespace gptq_marlin
+}  // namespace marlin
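The new repack constants pin the Marlin tile at tile_k_size x tile_n_size = 16 x 64. A small constexpr check of how many packed 32-bit words one such tile occupies (the per-tile output stride used by the repack kernels); the helper below is illustrative and not part of the header.

#include <cstdio>

static constexpr int tile_size = 16;
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;

// 16 x 64 values per tile, packed pack_factor = 32 / num_bits values per word.
constexpr int packed_ints_per_tile(int num_bits) {
  return tile_k_size * tile_n_size / (32 / num_bits);
}

int main() {
  static_assert(packed_ints_per_tile(4) == 128, "4-bit tile packs to 128 words");
  static_assert(packed_ints_per_tile(8) == 256, "8-bit tile packs to 256 words");
  printf("tile = %d x %d\n", tile_k_size, tile_n_size);
  return 0;
}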
@@ -1,11 +1,11 @@

 #ifndef _data_types_cuh
 #define _data_types_cuh
-#include "gptq_marlin.cuh"
+#include "marlin.cuh"
 #include <cuda_fp16.h>
 #include <cuda_bf16.h>

-namespace gptq_marlin {
+namespace marlin {

 template <typename scalar_t>
 class ScalarType {};
@@ -23,6 +23,7 @@ class ScalarType<half> {
   using FragB = Vec<half2, 2>;
   using FragC = Vec<float, 4>;
   using FragS = Vec<half2, 1>;
+  using FragZP = Vec<half2, 4>;

   static __device__ float inline num2float(const half x) {
     return __half2float(x);
@@ -51,6 +52,7 @@ class ScalarType<nv_bfloat16> {
   using FragB = Vec<nv_bfloat162, 2>;
   using FragC = Vec<float, 4>;
   using FragS = Vec<nv_bfloat162, 1>;
+  using FragZP = Vec<nv_bfloat162, 4>;

 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   static __device__ float inline num2float(const nv_bfloat16 x) {
@@ -72,6 +74,6 @@ class ScalarType<nv_bfloat16> {
 #endif
 };

-}  // namespace gptq_marlin
+}  // namespace marlin

 #endif
@@ -30,7 +30,7 @@ inline std::string str(T x) {
|
|||||||
return std::to_string(x);
|
return std::to_string(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace marlin {
|
namespace marlin_dense {
|
||||||
|
|
||||||
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
|
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
|
||||||
|
|
||||||
@@ -1040,7 +1040,7 @@ void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace marlin
|
} // namespace marlin_dense
|
||||||
|
|
||||||
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||||
torch::Tensor& b_scales, torch::Tensor& workspace,
|
torch::Tensor& b_scales, torch::Tensor& workspace,
|
||||||
@@ -1054,24 +1054,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
|||||||
TORCH_CHECK(size_k == a.size(1),
|
TORCH_CHECK(size_k == a.size(1),
|
||||||
"Shape mismatch: a.size(1) = " + str(a.size(1)) +
|
"Shape mismatch: a.size(1) = " + str(a.size(1)) +
|
||||||
", size_k = " + str(size_k));
|
", size_k = " + str(size_k));
|
||||||
TORCH_CHECK(size_k % marlin::tile_size == 0,
|
TORCH_CHECK(size_k % marlin_dense::tile_size == 0,
|
||||||
"size_k = " + str(size_k) +
|
"size_k = " + str(size_k) + " is not divisible by tile_size = " +
|
||||||
" is not divisible by tile_size = " + str(marlin::tile_size));
|
str(marlin_dense::tile_size));
|
||||||
TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
|
TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0),
|
||||||
"Shape mismatch: b_q_weight.size(0) = " +
|
"Shape mismatch: b_q_weight.size(0) = " +
|
||||||
str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
|
str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
|
||||||
", tile_size = " + str(marlin::tile_size));
|
", tile_size = " + str(marlin_dense::tile_size));
|
||||||
|
|
||||||
// Verify N
|
// Verify N
|
||||||
TORCH_CHECK(b_scales.size(1) == size_n,
|
TORCH_CHECK(b_scales.size(1) == size_n,
|
||||||
"b_scales.size(1) = " + str(b_scales.size(1)) +
|
"b_scales.size(1) = " + str(b_scales.size(1)) +
|
||||||
", size_n = " + str(size_n));
|
", size_n = " + str(size_n));
|
||||||
TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
|
TORCH_CHECK(
|
||||||
"b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
|
b_q_weight.size(1) % marlin_dense::tile_size == 0,
|
||||||
" is not divisible by tile_size = " + str(marlin::tile_size));
|
"b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
|
||||||
|
" is not divisible by tile_size = " + str(marlin_dense::tile_size));
|
||||||
|
|
||||||
int actual_size_n =
|
int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) *
|
||||||
(b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit;
|
marlin_dense::pack_factor_4bit;
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
size_n == actual_size_n,
|
size_n == actual_size_n,
|
||||||
"size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
|
"size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
|
||||||
@@ -1116,21 +1117,22 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
|||||||
"Unexpected groupsize = " + str(groupsize));
|
"Unexpected groupsize = " + str(groupsize));
|
||||||
|
|
||||||
// Verify workspace size
|
// Verify workspace size
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0,
|
||||||
size_n % marlin::min_thread_n == 0,
|
"size_n = " + str(size_n) +
|
||||||
"size_n = " + str(size_n) +
|
", is not divisible by min_thread_n = " +
|
||||||
", is not divisible by min_thread_n = " + str(marlin::min_thread_n));
|
str(marlin_dense::min_thread_n));
|
||||||
int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
|
int min_workspace_size =
|
||||||
|
(size_n / marlin_dense::min_thread_n) * marlin_dense::max_par;
|
||||||
TORCH_CHECK(workspace.numel() >= min_workspace_size,
|
TORCH_CHECK(workspace.numel() >= min_workspace_size,
|
||||||
"workspace.numel = " + str(workspace.numel()) +
|
"workspace.numel = " + str(workspace.numel()) +
|
||||||
" is below min_workspace_size = " + str(min_workspace_size));
|
" is below min_workspace_size = " + str(min_workspace_size));
|
||||||
|
|
||||||
int dev = a.get_device();
|
int dev = a.get_device();
|
||||||
marlin::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
|
marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
|
||||||
b_scales.data_ptr(), size_m, size_n, size_k,
|
b_scales.data_ptr(), size_m, size_n, size_k,
|
||||||
workspace.data_ptr(), groupsize, dev,
|
workspace.data_ptr(), groupsize, dev,
|
||||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n,
|
at::cuda::getCurrentCUDAStream(dev), thread_k,
|
||||||
sms, marlin::max_par);
|
thread_n, sms, marlin_dense::max_par);
|
||||||
|
|
||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,8 +27,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||||
" int max_seq_len, Tensor? alibi_slopes,"
|
" int max_seq_len, Tensor? alibi_slopes,"
|
||||||
" str kv_cache_dtype, float kv_scale, int tp_rank,"
|
" str kv_cache_dtype, float k_scale, float v_scale,"
|
||||||
" int blocksparse_local_blocks,"
|
" int tp_rank, int blocksparse_local_blocks,"
|
||||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||||
" int blocksparse_head_sliding_step) -> ()");
|
" int blocksparse_head_sliding_step) -> ()");
|
||||||
ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
|
ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
|
||||||
@@ -41,8 +41,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||||
" int max_seq_len, Tensor? alibi_slopes,"
|
" int max_seq_len, Tensor? alibi_slopes,"
|
||||||
" str kv_cache_dtype, float kv_scale, int tp_rank,"
|
" str kv_cache_dtype, float k_scale, float v_scale,"
|
||||||
" int blocksparse_local_blocks,"
|
" int tp_rank, int blocksparse_local_blocks,"
|
||||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||||
" int blocksparse_head_sliding_step) -> ()");
|
" int blocksparse_head_sliding_step) -> ()");
|
||||||
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
|
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
|
||||||
@@ -72,6 +72,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
|
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
|
||||||
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
|
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
|
||||||
|
|
||||||
|
// prepare_inputs advance_step
|
||||||
|
ops.def("advance_step", &advance_step);
|
||||||
|
ops.impl("advance_step", torch::kCUDA, &advance_step);
|
||||||
|
|
||||||
// Layernorm
|
// Layernorm
|
||||||
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
||||||
ops.def(
|
ops.def(
|
||||||
@@ -137,6 +141,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.def("gptq_marlin_repack", &gptq_marlin_repack);
|
ops.def("gptq_marlin_repack", &gptq_marlin_repack);
|
||||||
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
|
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
|
||||||
|
|
||||||
|
// awq_marlin repack from AWQ.
|
||||||
|
ops.def("awq_marlin_repack", &awq_marlin_repack);
|
||||||
|
ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
|
||||||
|
|
||||||
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
|
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
|
||||||
ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
|
ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
|
||||||
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
|
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
|
||||||
@@ -175,12 +183,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
|
"static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
|
||||||
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
|
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
|
||||||
|
|
||||||
// Compute FP8 quantized tensor and scaling factor.
|
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
|
||||||
ops.def(
|
ops.def(
|
||||||
"dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
|
"dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
|
||||||
"()");
|
"()");
|
||||||
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
|
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
|
||||||
|
|
||||||
|
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
|
||||||
|
ops.def(
|
||||||
|
"dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
|
||||||
|
"scale, Tensor? scale_ub) -> "
|
||||||
|
"()");
|
||||||
|
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
|
||||||
|
&dynamic_per_token_scaled_fp8_quant);
|
||||||
|
|
||||||
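The new ``dynamic_per_token_scaled_fp8_quant`` op computes one scale per token (per input row) rather than a single tensor-wide scale. As a rough, hedged sketch of the idea only (plain PyTorch, not the registered CUDA kernel; the 448.0 maximum assumes the e4m3 FP8 format):

```python
import torch

def dynamic_per_token_fp8_quant(x: torch.Tensor, fp8_max: float = 448.0):
    # One scale per row: the row-wise absolute maximum divided by the FP8 max.
    scale = x.abs().amax(dim=-1, keepdim=True).float() / fp8_max
    scale = scale.clamp(min=1e-12)  # guard against all-zero rows
    q = (x / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q, scale
```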
// Aligning the number of tokens to be processed by each expert such
|
// Aligning the number of tokens to be processed by each expert such
|
||||||
// that it is divisible by the block size.
|
// that it is divisible by the block size.
|
||||||
ops.def(
|
ops.def(
|
||||||
@@ -223,7 +239,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
|||||||
" Tensor! key_cache, Tensor! value_cache,"
|
" Tensor! key_cache, Tensor! value_cache,"
|
||||||
" Tensor slot_mapping,"
|
" Tensor slot_mapping,"
|
||||||
" str kv_cache_dtype,"
|
" str kv_cache_dtype,"
|
||||||
" float kv_scale) -> ()");
|
" float k_scale, float v_scale) -> ()");
|
||||||
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
|
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
|
||||||
|
|
||||||
// Reshape the key and value tensors and cache them.
|
// Reshape the key and value tensors and cache them.
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ sphinx==6.2.1
|
|||||||
sphinx-book-theme==1.0.1
|
sphinx-book-theme==1.0.1
|
||||||
sphinx-copybutton==0.5.2
|
sphinx-copybutton==0.5.2
|
||||||
myst-parser==2.0.0
|
myst-parser==2.0.0
|
||||||
sphinx-argparse
|
sphinx-argparse==0.4.0
|
||||||
|
|
||||||
# packages to install to build the documentation
|
# packages to install to build the documentation
|
||||||
pydantic
|
pydantic
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ vLLM is a community project. Our compute resources for development and testing a
|
|||||||
- Databricks
|
- Databricks
|
||||||
- DeepInfra
|
- DeepInfra
|
||||||
- Dropbox
|
- Dropbox
|
||||||
|
- Google Cloud
|
||||||
- Lambda Lab
|
- Lambda Lab
|
||||||
- NVIDIA
|
- NVIDIA
|
||||||
- Replicate
|
- Replicate
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ Multi-Modality
|
|||||||
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
|
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
|
||||||
|
|
||||||
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
|
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
|
||||||
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptStrictInputs`.
|
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.
|
||||||
|
|
||||||
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
|
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
|
||||||
by following :ref:`this guide <adding_multimodal_plugin>`.
|
by following :ref:`this guide <adding_multimodal_plugin>`.
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
LLM Inputs
|
LLM Inputs
|
||||||
==========
|
==========
|
||||||
|
|
||||||
.. autodata:: vllm.inputs.PromptStrictInputs
|
.. autodata:: vllm.inputs.PromptInputs
|
||||||
|
|
||||||
.. autoclass:: vllm.inputs.TextPrompt
|
.. autoclass:: vllm.inputs.TextPrompt
|
||||||
:show-inheritance:
|
:show-inheritance:
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
Installation with ROCm
|
Installation with ROCm
|
||||||
======================
|
======================
|
||||||
|
|
||||||
vLLM supports AMD GPUs with ROCm 5.7 and 6.0.
|
vLLM supports AMD GPUs with ROCm 6.1.
|
||||||
|
|
||||||
Requirements
|
Requirements
|
||||||
------------
|
------------
|
||||||
@@ -11,7 +11,7 @@ Requirements
|
|||||||
* OS: Linux
|
* OS: Linux
|
||||||
* Python: 3.8 -- 3.11
|
* Python: 3.8 -- 3.11
|
||||||
* GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
|
* GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
|
||||||
* ROCm 6.0 and ROCm 5.7
|
* ROCm 6.1
|
||||||
|
|
||||||
Installation options:
|
Installation options:
|
||||||
|
|
||||||
@@ -27,10 +27,10 @@ You can build and install vLLM from source.
|
|||||||
|
|
||||||
First, build a docker image from `Dockerfile.rocm <https://github.com/vllm-project/vllm/blob/main/Dockerfile.rocm>`_ and launch a docker container from the image.
|
First, build a docker image from `Dockerfile.rocm <https://github.com/vllm-project/vllm/blob/main/Dockerfile.rocm>`_ and launch a docker container from the image.
|
||||||
|
|
||||||
`Dockerfile.rocm <https://github.com/vllm-project/vllm/blob/main/Dockerfile.rocm>`_ uses ROCm 6.0 by default, but also supports ROCm 5.7.
|
`Dockerfile.rocm <https://github.com/vllm-project/vllm/blob/main/Dockerfile.rocm>`_ uses ROCm 6.1 by default, but also supports ROCm 5.7 and 6.0 in older vLLM branches.
|
||||||
It provides flexibility to customize the build of docker image using the following arguments:
|
It provides flexibility to customize the build of docker image using the following arguments:
|
||||||
|
|
||||||
* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`
|
* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image.
|
||||||
* `BUILD_FA`: specifies whether to build CK flash-attention. The default is 1. For `Radeon RX 7900 series (gfx1100) <https://rocm.docs.amd.com/projects/radeon/en/latest/index.html>`_, this should be set to 0 before flash-attention supports this target.
|
* `BUILD_FA`: specifies whether to build CK flash-attention. The default is 1. For `Radeon RX 7900 series (gfx1100) <https://rocm.docs.amd.com/projects/radeon/en/latest/index.html>`_, this should be set to 0 before flash-attention supports this target.
|
||||||
* `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build CK flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
|
* `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build CK flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
|
||||||
* `FA_BRANCH`: specifies the branch used to build the CK flash-attention in `ROCm's flash-attention repo <https://github.com/ROCmSoftwarePlatform/flash-attention>`_. The default is `ae7928c`
|
* `FA_BRANCH`: specifies the branch used to build the CK flash-attention in `ROCm's flash-attention repo <https://github.com/ROCmSoftwarePlatform/flash-attention>`_. The default is `ae7928c`
|
||||||
@@ -39,24 +39,17 @@ It provides flexibility to customize the build of docker image using the followi
|
|||||||
Their values can be passed in when running ``docker build`` with ``--build-arg`` options.
|
Their values can be passed in when running ``docker build`` with ``--build-arg`` options.
|
||||||
|
|
||||||
|
|
||||||
To build vllm on ROCm 6.0 for MI200 and MI300 series, you can use the default:
|
To build vllm on ROCm 6.1 for MI200 and MI300 series, you can use the default:
|
||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
$ docker build -f Dockerfile.rocm -t vllm-rocm .
|
$ DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm .
|
||||||
|
|
||||||
To build vllm on ROCm 6.0 for Radeon RX7900 series (gfx1100), you should specify ``BUILD_FA`` as below:
|
To build vLLM on ROCm 6.1 for Radeon RX 7900 series (gfx1100), you should specify ``BUILD_FA`` as below:
|
||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
$ docker build --build-arg BUILD_FA="0" -f Dockerfile.rocm -t vllm-rocm .
|
$ DOCKER_BUILDKIT=1 docker build --build-arg BUILD_FA="0" -f Dockerfile.rocm -t vllm-rocm .
|
||||||
|
|
||||||
To build docker image for vllm on ROCm 5.7, you can specify ``BASE_IMAGE`` as below:
|
|
||||||
|
|
||||||
.. code-block:: console
|
|
||||||
|
|
||||||
$ docker build --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
|
|
||||||
-f Dockerfile.rocm -t vllm-rocm .
|
|
||||||
|
|
||||||
To run the above docker image ``vllm-rocm``, use the below command:
|
To run the above docker image ``vllm-rocm``, use the below command:
|
||||||
|
|
||||||
@@ -85,39 +78,24 @@ Option 2: Build from source
|
|||||||
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
|
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
|
||||||
|
|
||||||
- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
|
- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
|
||||||
- `Pytorch <https://pytorch.org/>`_
|
- `PyTorch <https://pytorch.org/>`_
|
||||||
- `hipBLAS <https://rocm.docs.amd.com/projects/hipBLAS/en/latest/install.html>`_
|
- `hipBLAS <https://rocm.docs.amd.com/projects/hipBLAS/en/latest/install.html>`_
|
||||||
|
|
||||||
For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging`, `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`, `rocm/pytorch-nightly`.
|
For installing PyTorch, you can start from a fresh docker image, e.g., `rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging`, `rocm/pytorch-nightly`.
|
||||||
|
|
||||||
Alternatively, you can install pytorch using pytorch wheels. You can check Pytorch installation guild in Pytorch `Getting Started <https://pytorch.org/get-started/locally/>`_
|
Alternatively, you can install PyTorch using PyTorch wheels. You can check the PyTorch installation guide in PyTorch `Getting Started <https://pytorch.org/get-started/locally/>`_
|
||||||
|
|
||||||
For rocm6.0:
|
|
||||||
|
|
||||||
.. code-block:: console
|
|
||||||
|
|
||||||
$ pip3 install torch --index-url https://download.pytorch.org/whl/rocm6.0
|
|
||||||
|
|
||||||
|
|
||||||
For rocm5.7:
|
|
||||||
|
|
||||||
.. code-block:: console
|
|
||||||
|
|
||||||
$ pip install torch --index-url https://download.pytorch.org/whl/rocm5.7
|
|
||||||
|
|
||||||
|
|
||||||
1. Install `Triton flash attention for ROCm <https://github.com/ROCm/triton>`_
|
1. Install `Triton flash attention for ROCm <https://github.com/ROCm/triton>`_
|
||||||
|
|
||||||
Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from `ROCm/triton <https://github.com/ROCm/triton/blob/triton-mlir/README.md>`_
|
Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from `ROCm/triton <https://github.com/ROCm/triton/blob/triton-mlir/README.md>`_
|
||||||
|
|
||||||
2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm <https://github.com/ROCm/flash-attention/tree/flash_attention_for_rocm>`_
|
2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm <https://github.com/ROCm/flash-attention/tree/ck_tile>`_
|
||||||
|
|
||||||
Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/flash-attention <https://github.com/ROCm/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_
|
Install ROCm's flash attention (v2.5.9.post1) following the instructions from `ROCm/flash-attention <https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support>`_
|
||||||
|
Alternatively, wheels intended for vLLM use can be accessed under the releases.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
- If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly.
|
|
||||||
- If you fail to install `ROCm/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`.
|
|
||||||
- ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
|
|
||||||
- You might need to downgrade the "ninja" version to 1.10, as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
|
- You might need to downgrade the "ninja" version to 1.10, as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
|
||||||
|
|
||||||
3. Build vLLM.
|
3. Build vLLM.
|
||||||
@@ -131,7 +109,7 @@ Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/fl
|
|||||||
|
|
||||||
.. tip::
|
.. tip::
|
||||||
|
|
||||||
- You may need to turn on the ``--enforce-eager`` flag if you experience process hang when running the `benchmark_thoughput.py` script to test your installation.
|
|
||||||
- Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
|
- Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
|
||||||
- To use CK flash-attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention.
|
- Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support.
|
||||||
- The ROCm version of pytorch, ideally, should match the ROCm driver version.
|
- To use CK flash-attention or PyTorch naive attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention.
|
||||||
|
- The ROCm version of PyTorch, ideally, should match the ROCm driver version.
|
||||||
|
|||||||
@@ -19,9 +19,6 @@ If you have already taken care of the above issues, but the vLLM instance still
|
|||||||
- Set the environment variable ``export NCCL_DEBUG=TRACE`` to turn on more logging for NCCL.
|
- Set the environment variable ``export NCCL_DEBUG=TRACE`` to turn on more logging for NCCL.
|
||||||
- Set the environment variable ``export VLLM_TRACE_FUNCTION=1``. All the function calls in vLLM will be recorded. Inspect these log files to tell which function crashes or hangs.
|
- Set the environment variable ``export VLLM_TRACE_FUNCTION=1``. All the function calls in vLLM will be recorded. Inspect these log files to tell which function crashes or hangs.
|
||||||
|
|
||||||
.. warning::
|
|
||||||
vLLM function tracing will generate a lot of logs and slow down the system. Only use it for debugging purposes.
|
|
||||||
|
|
||||||
With more logging, hopefully you can find the root cause of the issue.
|
With more logging, hopefully you can find the root cause of the issue.
|
||||||
|
|
||||||
If it crashes, and the error trace shows somewhere around ``self.graph.replay()`` in ``vllm/worker/model_runner.py``, it is a cuda error inside cudagraph. To know the particular cuda operation that causes the error, you can add ``--enforce-eager`` to the command line, or ``enforce_eager=True`` to the ``LLM`` class, to disable the cudagraph optimization. This way, you can locate the exact cuda operation that causes the error.
|
If it crashes, and the error trace shows somewhere around ``self.graph.replay()`` in ``vllm/worker/model_runner.py``, it is a cuda error inside cudagraph. To know the particular cuda operation that causes the error, you can add ``--enforce-eager`` to the command line, or ``enforce_eager=True`` to the ``LLM`` class, to disable the cudagraph optimization. This way, you can locate the exact cuda operation that causes the error.
|
||||||
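As a minimal sketch of the second option (the model name is an assumed placeholder):

.. code-block:: python

    from vllm import LLM

    # enforce_eager=True disables the cudagraph optimization, so the failing
    # CUDA operation surfaces directly instead of inside graph replay.
    llm = LLM(model="facebook/opt-125m", enforce_eager=True)
    outputs = llm.generate(["Hello, my name is"])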
@@ -67,3 +64,7 @@ Here are some common issues that can cause hangs:
|
|||||||
If the script runs successfully, you should see the message ``sanity check is successful!``.
|
If the script runs successfully, you should see the message ``sanity check is successful!``.
|
||||||
|
|
||||||
If the problem persists, feel free to `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_, with a detailed description of the issue, your environment, and the logs.
|
If the problem persists, feel free to `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_, with a detailed description of the issue, your environment, and the logs.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
After you find the root cause and solve the issue, remember to turn off all the debugging environment variables defined above, or simply start a new shell to avoid being affected by the debugging settings. If you don't do this, the system might be slow because many debugging functionalities are turned on.
|
||||||
|
|||||||
@@ -49,11 +49,10 @@ You can install vLLM using pip:
|
|||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
$ export VLLM_VERSION=0.5.2 # vLLM's main branch version is currently set to latest released tag
|
$ export VLLM_VERSION=0.5.2 # vLLM's main branch version is currently set to latest released tag
|
||||||
$ export PYTHON_VERSION=310
|
$ pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-${VLLM_VERSION}-cp38-abi3-manylinux1_x86_64.whl
|
||||||
$ pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-${VLLM_VERSION}-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl
|
|
||||||
$ # You can also access a specific commit
|
$ # You can also access a specific commit
|
||||||
$ # export VLLM_COMMIT=...
|
$ # export VLLM_COMMIT=...
|
||||||
$ # pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/${VLLM_COMMIT}/vllm-${VLLM_VERSION}-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl
|
$ # pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/${VLLM_COMMIT}/vllm-${VLLM_VERSION}-cp38-abi3-manylinux1_x86_64.whl
|
||||||
|
|
||||||
|
|
||||||
.. _build_from_source:
|
.. _build_from_source:
|
||||||
|
|||||||
@@ -73,16 +73,13 @@ Start the server:
|
|||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
$ python -m vllm.entrypoints.openai.api_server \
|
$ vllm serve facebook/opt-125m
|
||||||
$ --model facebook/opt-125m
|
|
||||||
|
|
||||||
By default, the server uses a predefined chat template stored in the tokenizer. You can override this template by using the ``--chat-template`` argument:
|
By default, the server uses a predefined chat template stored in the tokenizer. You can override this template by using the ``--chat-template`` argument:
|
||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
$ python -m vllm.entrypoints.openai.api_server \
|
$ vllm serve facebook/opt-125m --chat-template ./examples/template_chatml.jinja
|
||||||
$ --model facebook/opt-125m \
|
|
||||||
$ --chat-template ./examples/template_chatml.jinja
|
|
||||||
|
|
||||||
This server can be queried in the same format as OpenAI API. For example, list the models:
|
This server can be queried in the same format as OpenAI API. For example, list the models:
|
||||||
|
|
||||||
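A minimal query, assuming the server is listening on the default port 8000:

.. code-block:: console

    $ curl http://localhost:8000/v1/models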
|
|||||||
@@ -56,7 +56,7 @@ First, install the dependencies:
|
|||||||
$ pip uninstall torch torch-xla -y
|
$ pip uninstall torch torch-xla -y
|
||||||
|
|
||||||
$ # Install PyTorch and PyTorch XLA.
|
$ # Install PyTorch and PyTorch XLA.
|
||||||
$ export DATE="+20240601"
|
$ export DATE="+20240713"
|
||||||
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
|
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
|
||||||
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
|
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
|
||||||
|
|
||||||
@@ -85,7 +85,7 @@ Next, build vLLM from source. This will only take a few seconds:
|
|||||||
ImportError: libopenblas.so.0: cannot open shared object file: No such file or directory
|
ImportError: libopenblas.so.0: cannot open shared object file: No such file or directory
|
||||||
|
|
||||||
|
|
||||||
You can install OpenBLAS with the following command:
|
Please install OpenBLAS with the following command:
|
||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ Just add the following lines in your code:
|
|||||||
from your_code import YourModelForCausalLM
|
from your_code import YourModelForCausalLM
|
||||||
ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
|
ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
|
||||||
|
|
||||||
If you are running api server with `python -m vllm.entrypoints.openai.api_server args`, you can wrap the entrypoint with the following code:
|
If you are running the API server with :code:`vllm serve <args>`, you can wrap the entrypoint with the following code:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@@ -124,4 +124,4 @@ If you are running api server with `python -m vllm.entrypoints.openai.api_server
|
|||||||
import runpy
|
import runpy
|
||||||
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
|
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
|
||||||
|
|
||||||
Save the above code in a file and run it with `python your_file.py args`.
|
Save the above code in a file and run it with :code:`python your_file.py <args>`.
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ Below, you can find an explanation of every engine argument for vLLM:
|
|||||||
.. argparse::
|
.. argparse::
|
||||||
:module: vllm.engine.arg_utils
|
:module: vllm.engine.arg_utils
|
||||||
:func: _engine_args_parser
|
:func: _engine_args_parser
|
||||||
:prog: -m vllm.entrypoints.openai.api_server
|
:prog: vllm serve
|
||||||
:nodefaultconst:
|
:nodefaultconst:
|
||||||
|
|
||||||
Async Engine Arguments
|
Async Engine Arguments
|
||||||
@@ -19,5 +19,5 @@ Below are the additional arguments related to the asynchronous engine:
|
|||||||
.. argparse::
|
.. argparse::
|
||||||
:module: vllm.engine.arg_utils
|
:module: vllm.engine.arg_utils
|
||||||
:func: _async_engine_args_parser
|
:func: _async_engine_args_parser
|
||||||
:prog: -m vllm.entrypoints.openai.api_server
|
:prog: vllm serve
|
||||||
:nodefaultconst:
|
:nodefaultconst:
|
||||||
@@ -61,10 +61,12 @@ LoRA adapted models can also be served with the Open-AI compatible vLLM server.
|
|||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
python -m vllm.entrypoints.openai.api_server \
|
vllm serve meta-llama/Llama-2-7b-hf \
|
||||||
--model meta-llama/Llama-2-7b-hf \
|
|
||||||
--enable-lora \
|
--enable-lora \
|
||||||
--lora-modules sql-lora=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/
|
--lora-modules sql-lora=$HOME/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The commit ID `0dfa347e8877a4d4ed19ee56c140fa518470028c` may change over time. Please check the latest commit ID in your environment to ensure you are using the correct one.
|
||||||
|
|
||||||
The server entrypoint accepts all other LoRA configuration parameters (``max_loras``, ``max_lora_rank``, ``max_cpu_loras``,
|
The server entrypoint accepts all other LoRA configuration parameters (``max_loras``, ``max_lora_rank``, ``max_cpu_loras``,
|
||||||
etc.), which will apply to all forthcoming requests. Upon querying the ``/models`` endpoint, we should see our LoRA along
|
etc.), which will apply to all forthcoming requests. Upon querying the ``/models`` endpoint, we should see our LoRA along
|
||||||
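A minimal sketch of such a query, assuming the server runs on the default port and the adapter was registered under the name ``sql-lora`` as above:

.. code-block:: bash

    # List the served models; the LoRA adapter should appear alongside the base model.
    curl http://localhost:8000/v1/models

    # Requests can then target the adapter by name.
    curl http://localhost:8000/v1/completions \
        -H "Content-Type: application/json" \
        -d '{
            "model": "sql-lora",
            "prompt": "San Francisco is a",
            "max_tokens": 7,
            "temperature": 0
        }'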
|
|||||||
@@ -94,8 +94,8 @@ Decoder-only Language Models
|
|||||||
- :code:`ai21labs/Jamba-v0.1`, etc.
|
- :code:`ai21labs/Jamba-v0.1`, etc.
|
||||||
- ✅︎
|
- ✅︎
|
||||||
* - :code:`LlamaForCausalLM`
|
* - :code:`LlamaForCausalLM`
|
||||||
- LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi
|
- Llama 3.1, Llama 3, Llama 2, LLaMA, Yi
|
||||||
- :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
|
- :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc.
|
||||||
- ✅︎
|
- ✅︎
|
||||||
* - :code:`MiniCPMForCausalLM`
|
* - :code:`MiniCPMForCausalLM`
|
||||||
- MiniCPM
|
- MiniCPM
|
||||||
@@ -182,6 +182,10 @@ Vision Language Models
|
|||||||
- Models
|
- Models
|
||||||
- Example HuggingFace Models
|
- Example HuggingFace Models
|
||||||
- :ref:`LoRA <lora>`
|
- :ref:`LoRA <lora>`
|
||||||
|
* - :code:`ChameleonForConditionalGeneration`
|
||||||
|
- Chameleon
|
||||||
|
- :code:`facebook/chameleon-7b` etc.
|
||||||
|
-
|
||||||
* - :code:`FuyuForCausalLM`
|
* - :code:`FuyuForCausalLM`
|
||||||
- Fuyu
|
- Fuyu
|
||||||
- :code:`adept/fuyu-8b` etc.
|
- :code:`adept/fuyu-8b` etc.
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
|
|||||||
internally for each model.
|
internally for each model.
|
||||||
|
|
||||||
|
|
||||||
To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
|
To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
|
||||||
|
|
||||||
* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
|
* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
|
||||||
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
|
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
|
||||||
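Putting the two fields together, a minimal offline-inference sketch (mirroring the updated example scripts in this change; the asset name and chat format are assumptions for illustration):

.. code-block:: python

    from vllm import LLM
    from vllm.assets.image import ImageAsset

    llm = LLM(model="llava-hf/llava-1.5-7b-hf")

    prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
    image = ImageAsset("stop_sign").pil_image

    # ``multi_modal_data`` follows the MultiModalDataDict schema described above.
    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    })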
@@ -94,9 +94,7 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with
|
|||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
python -m vllm.entrypoints.openai.api_server \
|
vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja
|
||||||
--model llava-hf/llava-1.5-7b-hf \
|
|
||||||
--chat-template template_llava.jinja
|
|
||||||
|
|
||||||
.. important::
|
.. important::
|
||||||
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
|
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
|
||||||
|
|||||||
@@ -28,6 +28,9 @@ Next, to install the required packages, add the following to your cerebrium.toml
|
|||||||
|
|
||||||
.. code-block:: toml
|
.. code-block:: toml
|
||||||
|
|
||||||
|
[cerebrium.deployment]
|
||||||
|
docker_base_image_url = "nvidia/cuda:12.1.1-runtime-ubuntu22.04"
|
||||||
|
|
||||||
[cerebrium.dependencies.pip]
|
[cerebrium.dependencies.pip]
|
||||||
vllm = "latest"
|
vllm = "latest"
|
||||||
|
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ Next, to provision a VM instance with LLM of your choice(`NousResearch/Llama-2-7
|
|||||||
gpu: 24GB
|
gpu: 24GB
|
||||||
commands:
|
commands:
|
||||||
- pip install vllm
|
- pip install vllm
|
||||||
- python -m vllm.entrypoints.openai.api_server --model $MODEL --port 8000
|
- vllm serve $MODEL --port 8000
|
||||||
model:
|
model:
|
||||||
format: openai
|
format: openai
|
||||||
type: chat
|
type: chat
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
.. _distributed_serving:
|
.. _distributed_serving:
|
||||||
|
|
||||||
|
Distributed Inference and Serving
|
||||||
|
=================================
|
||||||
|
|
||||||
How to decide the distributed inference strategy?
|
How to decide the distributed inference strategy?
|
||||||
=================================================
|
-------------------------------------------------
|
||||||
|
|
||||||
Before going into the details of distributed inference and serving, let's first make it clear when to use distributed inference and what strategies are available. The common practice is:
|
Before going into the details of distributed inference and serving, let's first make it clear when to use distributed inference and what strategies are available. The common practice is:
|
||||||
|
|
||||||
@@ -16,8 +19,8 @@ After adding enough GPUs and nodes to hold the model, you can run vLLM first, wh
|
|||||||
.. note::
|
.. note::
|
||||||
There is one edge case: if the model fits in a single node with multiple GPUs, but the number of GPUs cannot divide the model size evenly, you can use pipeline parallelism, which splits the model along layers and supports uneven splits. In this case, the tensor parallel size should be 1 and the pipeline parallel size should be the number of GPUs.
|
There is one edge case: if the model fits in a single node with multiple GPUs, but the number of GPUs cannot divide the model size evenly, you can use pipeline parallelism, which splits the model along layers and supports uneven splits. In this case, the tensor parallel size should be 1 and the pipeline parallel size should be the number of GPUs.
|
||||||
|
|
||||||
Distributed Inference and Serving
|
Details for Distributed Inference and Serving
|
||||||
=================================
|
----------------------------------------------
|
||||||
|
|
||||||
vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We also support pipeline parallel as a beta feature for online serving. We manage the distributed runtime with either `Ray <https://github.com/ray-project/ray>`_ or python native multiprocessing. Multiprocessing can be used when deploying on a single node; multi-node inferencing currently requires Ray.
|
vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We also support pipeline parallel as a beta feature for online serving. We manage the distributed runtime with either `Ray <https://github.com/ray-project/ray>`_ or python native multiprocessing. Multiprocessing can be used when deploying on a single node; multi-node inferencing currently requires Ray.
|
||||||
|
|
||||||
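For offline inference with the ``LLM`` class, tensor parallelism is enabled through the ``tensor_parallel_size`` argument; a minimal sketch, with the model name assumed:

.. code-block:: python

    from vllm import LLM

    # Shard the model across 4 GPUs on this node.
    llm = LLM("facebook/opt-13b", tensor_parallel_size=4)
    output = llm.generate("San Francisco is a")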
@@ -35,36 +38,73 @@ To run multi-GPU serving, pass in the :code:`--tensor-parallel-size` argument wh
|
|||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
$ python -m vllm.entrypoints.openai.api_server \
|
$ vllm serve facebook/opt-13b \
|
||||||
$ --model facebook/opt-13b \
|
|
||||||
$ --tensor-parallel-size 4
|
$ --tensor-parallel-size 4
|
||||||
|
|
||||||
You can also additionally specify :code:`--pipeline-parallel-size` to enable pipeline parallelism. For example, to run API server on 8 GPUs with pipeline parallelism and tensor parallelism:
|
You can also additionally specify :code:`--pipeline-parallel-size` to enable pipeline parallelism. For example, to run API server on 8 GPUs with pipeline parallelism and tensor parallelism:
|
||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
$ python -m vllm.entrypoints.openai.api_server \
|
$ vllm serve gpt2 \
|
||||||
$ --model gpt2 \
|
|
||||||
$ --tensor-parallel-size 4 \
|
$ --tensor-parallel-size 4 \
|
||||||
$ --pipeline-parallel-size 2 \
|
$ --pipeline-parallel-size 2
|
||||||
$ --distributed-executor-backend ray
|
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
Pipeline parallel is a beta feature. It is only supported for online serving and the ray backend for now, as well as LLaMa and GPT2 style models.
|
Pipeline parallel is a beta feature. It is only supported for online serving, and only for LLaMa, GPT2, and Mixtral style models.
|
||||||
|
|
||||||
To scale vLLM beyond a single machine, install and start a `Ray runtime <https://docs.ray.io/en/latest/ray-core/starting-ray.html>`_ via CLI before running vLLM:
|
Multi-Node Inference and Serving
|
||||||
|
--------------------------------
|
||||||
|
|
||||||
|
If a single node does not have enough GPUs to hold the model, you can run the model using multiple nodes. It is important to make sure the execution environment is the same on all nodes, including the model path and the Python environment. The recommended way is to use docker images to ensure the same environment, and to hide the heterogeneity of the host machines by mapping them into the same docker configuration.
|
||||||
|
|
||||||
|
The first step is to start containers and organize them into a cluster. We provide a helper `script <https://github.com/vllm-project/vllm/tree/main/examples/run_cluster.sh>`_ to start the cluster.
|
||||||
|
|
||||||
|
Pick a node as the head node, and run the following command:
|
||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
$ pip install ray
|
$ bash run_cluster.sh \
|
||||||
|
$ vllm/vllm-openai \
|
||||||
|
$ ip_of_head_node \
|
||||||
|
$ --head \
|
||||||
|
$ /path/to/the/huggingface/home/in/this/node
|
||||||
|
|
||||||
$ # On head node
|
On the rest of the worker nodes, run the following command:
|
||||||
$ ray start --head
|
|
||||||
|
|
||||||
$ # On worker nodes
|
.. code-block:: console
|
||||||
$ ray start --address=<ray-head-address>
|
|
||||||
|
|
||||||
After that, you can run inference and serving on multiple machines by launching the vLLM process on the head node by setting :code:`tensor_parallel_size` multiplied by :code:`pipeline_parallel_size` to the number of GPUs to be the total number of GPUs across all machines.
|
$ bash run_cluster.sh \
|
||||||
|
$ vllm/vllm-openai \
|
||||||
|
$ ip_of_head_node \
|
||||||
|
$ --worker \
|
||||||
|
$ /path/to/the/huggingface/home/in/this/node
|
||||||
|
|
||||||
|
Then you get a Ray cluster of containers. Note that you need to keep the shells running these commands alive to hold the cluster; any shell disconnect will terminate the cluster.
|
||||||
|
|
||||||
|
Then, on any node, use ``docker exec -it node /bin/bash`` to enter the container, execute ``ray status`` to check the status of the Ray cluster. You should see the right number of nodes and GPUs.
|
||||||
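For example:

.. code-block:: console

    $ docker exec -it node /bin/bash
    $ ray status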
|
|
||||||
|
After that, on any node, you can use vLLM as usual, just as if you had all the GPUs on one node. The common practice is to set the tensor parallel size to the number of GPUs in each node, and the pipeline parallel size to the number of nodes. For example, if you have 16 GPUs in 2 nodes (8 GPUs per node), you can set the tensor parallel size to 8 and the pipeline parallel size to 2:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ vllm serve /path/to/the/model/in/the/container \
|
||||||
|
$ --tensor-parallel-size 8 \
|
||||||
|
$ --pipeline-parallel-size 2
|
||||||
|
|
||||||
|
You can also use tensor parallel without pipeline parallel; just set the tensor parallel size to the total number of GPUs in the cluster. For example, if you have 16 GPUs in 2 nodes (8 GPUs per node), you can set the tensor parallel size to 16:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ vllm serve /path/to/the/model/in/the/container \
|
||||||
|
$ --tensor-parallel-size 16
|
||||||
|
|
||||||
|
To make tensor parallel performant, you should make sure the communication between nodes is efficient, e.g. using high-speed network cards like InfiniBand. To correctly set up the cluster to use InfiniBand, append additional arguments like ``--privileged -e NCCL_IB_HCA=mlx5`` to the ``run_cluster.sh`` script. Please contact your system administrator for more information on how to set up the flags. One way to confirm that InfiniBand is working is to run vLLM with the ``NCCL_DEBUG=TRACE`` environment variable set, e.g. ``NCCL_DEBUG=TRACE vllm serve ...``, and check the logs for the NCCL version and the network used. If you find ``[send] via NET/Socket`` in the logs, it means NCCL uses raw TCP sockets, which is not efficient for cross-node tensor parallel. If you find ``[send] via NET/IB/GDRDMA`` in the logs, it means NCCL uses InfiniBand with GPU-Direct RDMA, which is efficient.
|
||||||
|
|
||||||
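For instance, a check along these lines (model and parallel sizes are assumed placeholders) prints the transport NCCL selects during startup:

.. code-block:: console

    $ NCCL_DEBUG=TRACE vllm serve facebook/opt-13b \
    $     --tensor-parallel-size 8 \
    $     --pipeline-parallel-size 2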
.. warning::
|
.. warning::
|
||||||
Please make sure you downloaded the model to all the nodes, or the model is downloaded to some distributed file system that is accessible by all nodes.
|
After you start the Ray cluster, you should also check the GPU-GPU communication between nodes; it can be non-trivial to set up. Please refer to the `sanity check script <https://docs.vllm.ai/en/latest/getting_started/debugging.html>`_ for more information.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
Please make sure you have downloaded the model to all the nodes (with the same path), or that the model is downloaded to some distributed file system that is accessible by all nodes.
|
||||||
|
|
||||||
|
When you use a Hugging Face repo id to refer to the model, you should append your Hugging Face token to the ``run_cluster.sh`` script, e.g. ``-e HF_TOKEN=``. The recommended way is to download the model first, and then use the path to refer to the model.
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ vLLM provides an HTTP server that implements OpenAI's [Completions](https://plat
|
|||||||
|
|
||||||
You can start the server using Python, or using [Docker](deploying_with_docker.rst):
|
You can start the server using Python, or using [Docker](deploying_with_docker.rst):
|
||||||
```bash
|
```bash
|
||||||
python -m vllm.entrypoints.openai.api_server --model NousResearch/Meta-Llama-3-8B-Instruct --dtype auto --api-key token-abc123
|
vllm serve NousResearch/Meta-Llama-3-8B-Instruct --dtype auto --api-key token-abc123
|
||||||
```
|
```
|
||||||
|
|
||||||
To call the server, you can use the official OpenAI Python client library, or any other HTTP client.
|
To call the server, you can use the official OpenAI Python client library, or any other HTTP client.
|
||||||
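For example, a minimal call with the official Python client, assuming the server command above (same base URL, API key, and model name):

```python
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="token-abc123",
)

completion = client.chat.completions.create(
    model="NousResearch/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(completion.choices[0].message)
```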
@@ -97,9 +97,7 @@ template, or the template in string form. Without a chat template, the server wi
|
|||||||
and all chat requests will error.
|
and all chat requests will error.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m vllm.entrypoints.openai.api_server \
|
vllm serve <model> --chat-template ./path-to-chat-template.jinja
|
||||||
--model ... \
|
|
||||||
--chat-template ./path-to-chat-template.jinja
|
|
||||||
```
|
```
|
||||||
|
|
||||||
vLLM community provides a set of chat templates for popular models. You can find them in the examples
|
vLLM community provides a set of chat templates for popular models. You can find them in the examples
|
||||||
@@ -110,7 +108,7 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
|
|||||||
```{argparse}
|
```{argparse}
|
||||||
:module: vllm.entrypoints.openai.cli_args
|
:module: vllm.entrypoints.openai.cli_args
|
||||||
:func: create_parser_for_docs
|
:func: create_parser_for_docs
|
||||||
:prog: -m vllm.entrypoints.openai.api_server
|
:prog: vllm serve
|
||||||
```
|
```
|
||||||
|
|
||||||
## Tool calling in the chat completion API
|
## Tool calling in the chat completion API
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
"""Example Python client for vllm.entrypoints.api_server
|
"""Example Python client for `vllm.entrypoints.api_server`
|
||||||
NOTE: The API server is used only for demonstration and simple performance
|
NOTE: The API server is used only for demonstration and simple performance
|
||||||
benchmarks. It is not intended for production use.
|
benchmarks. It is not intended for production use.
|
||||||
For production use, we recommend vllm.entrypoints.openai.api_server
|
For production use, we recommend `vllm serve` and the OpenAI client API.
|
||||||
and the OpenAI client API
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|||||||
examples/cpu_offload.py (new file, 22 lines)
@@ -0,0 +1,22 @@
|
|||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
# Sample prompts.
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
# Create a sampling params object.
|
||||||
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||||
|
|
||||||
|
# Create an LLM.
|
||||||
|
llm = LLM(model="meta-llama/Llama-2-13b-chat-hf", cpu_offload_gb=10)
|
||||||
|
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||||
|
# that contain the prompt, generated text, and other information.
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
# Print the outputs.
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
@@ -1,12 +1,5 @@
|
|||||||
import os
|
|
||||||
import subprocess
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from vllm import LLM
|
from vllm import LLM
|
||||||
|
from vllm.assets.image import ImageAsset
|
||||||
# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
|
|
||||||
# You can use `.buildkite/download-images.sh` to download them
|
|
||||||
|
|
||||||
|
|
||||||
def run_llava():
|
def run_llava():
|
||||||
@@ -14,7 +7,7 @@ def run_llava():
|
|||||||
|
|
||||||
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
|
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
|
||||||
|
|
||||||
image = Image.open("images/stop_sign.jpg")
|
image = ImageAsset("stop_sign").pil_image
|
||||||
|
|
||||||
outputs = llm.generate({
|
outputs = llm.generate({
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
@@ -28,25 +21,5 @@ def run_llava():
|
|||||||
print(generated_text)
|
print(generated_text)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
run_llava()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Download from s3
|
run_llava()
|
||||||
s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
|
|
||||||
local_directory = "images"
|
|
||||||
|
|
||||||
# Make sure the local directory exists or create it
|
|
||||||
os.makedirs(local_directory, exist_ok=True)
|
|
||||||
|
|
||||||
# Use AWS CLI to sync the directory, assume anonymous access
|
|
||||||
subprocess.check_call([
|
|
||||||
"aws",
|
|
||||||
"s3",
|
|
||||||
"sync",
|
|
||||||
s3_bucket_path,
|
|
||||||
local_directory,
|
|
||||||
"--no-sign-request",
|
|
||||||
])
|
|
||||||
main()
|
|
||||||
|
|||||||
@@ -95,9 +95,7 @@ to the path of the custom logging configuration JSON file:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \
|
VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \
|
||||||
python3 -m vllm.entrypoints.openai.api_server \
|
vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048
|
||||||
--max-model-len 2048 \
|
|
||||||
--model mistralai/Mistral-7B-v0.1
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
@@ -152,9 +150,7 @@ to the path of the custom logging configuration JSON file:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \
|
VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \
|
||||||
python3 -m vllm.entrypoints.openai.api_server \
|
vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048
|
||||||
--max-model-len 2048 \
|
|
||||||
--model mistralai/Mistral-7B-v0.1
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
@@ -167,9 +163,7 @@ loggers.
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
VLLM_CONFIGURE_LOGGING=0 \
|
VLLM_CONFIGURE_LOGGING=0 \
|
||||||
python3 -m vllm.entrypoints.openai.api_server \
|
vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048
|
||||||
--max-model-len 2048 \
|
|
||||||
--model mistralai/Mistral-7B-v0.1
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
examples/offline_inference_tpu.py (new file, 28 lines)
@@ -0,0 +1,28 @@
|
|||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
"A robot may not injure a human being",
|
||||||
|
"It is only with the heart that one can see rightly;",
|
||||||
|
"The greatest glory in living lies not in never falling,",
|
||||||
|
]
|
||||||
|
answers = [
|
||||||
|
" or, through inaction, allow a human being to come to harm.",
|
||||||
|
" what is essential is invisible to the eye.",
|
||||||
|
" but in rising every time we fall.",
|
||||||
|
]
|
||||||
|
N = 1
|
||||||
|
# Currently, top-p sampling is disabled. `top_p` should be 1.0.
|
||||||
|
sampling_params = SamplingParams(temperature=0.7,
|
||||||
|
top_p=1.0,
|
||||||
|
n=N,
|
||||||
|
max_tokens=16)
|
||||||
|
|
||||||
|
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
|
||||||
|
# In real workloads, `enforce_eager` should be `False`.
|
||||||
|
llm = LLM(model="google/gemma-2b", enforce_eager=True)
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
for output, answer in zip(outputs, answers):
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
assert generated_text.startswith(answer)
|
||||||
@@ -1,9 +1,7 @@
 """An example showing how to use vLLM to serve VLMs.
 
 Launch the vLLM server with the following command:
-python -m vllm.entrypoints.openai.api_server \
-    --model llava-hf/llava-1.5-7b-hf \
-    --chat-template template_llava.jinja
+vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja
 """
 import base64
 
@@ -1,12 +1,5 @@
-import os
-import subprocess
-
-from PIL import Image
-
 from vllm import LLM
-
-# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
-# You can use `.buildkite/download-images.sh` to download them
+from vllm.assets.image import ImageAsset
 
 
 def run_paligemma():
@@ -14,7 +7,7 @@ def run_paligemma():
 
     prompt = "caption es"
 
-    image = Image.open("images/stop_sign.jpg")
+    image = ImageAsset("stop_sign").pil_image
 
     outputs = llm.generate({
         "prompt": prompt,
@@ -28,25 +21,5 @@ def run_paligemma():
         print(generated_text)
 
 
-def main():
-    run_paligemma()
-
-
 if __name__ == "__main__":
-    # Download from s3
-    s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
-    local_directory = "images"
-
-    # Make sure the local directory exists or create it
-    os.makedirs(local_directory, exist_ok=True)
-
-    # Use AWS CLI to sync the directory, assume anonymous access
-    subprocess.check_call([
-        "aws",
-        "s3",
-        "sync",
-        s3_bucket_path,
-        local_directory,
-        "--no-sign-request",
-    ])
-    main()
+    run_paligemma()
@@ -1,12 +1,5 @@
-import os
-import subprocess
-
-from PIL import Image
-
 from vllm import LLM, SamplingParams
-
-# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
-# You can use `.buildkite/download-images.sh` to download them
+from vllm.assets.image import ImageAsset
 
 
 def run_phi3v():
@@ -24,7 +17,7 @@ def run_phi3v():
         max_num_seqs=5,
     )
 
-    image = Image.open("images/cherry_blossom.jpg")
+    image = ImageAsset("cherry_blossom").pil_image
 
     # single-image prompt
     prompt = "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n"  # noqa: E501
@@ -44,19 +37,4 @@ def run_phi3v():
 
 
 if __name__ == "__main__":
-    s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
-    local_directory = "images"
-
-    # Make sure the local directory exists or create it
-    os.makedirs(local_directory, exist_ok=True)
-
-    # Use AWS CLI to sync the directory, assume anonymous access
-    subprocess.check_call([
-        "aws",
-        "s3",
-        "sync",
-        s3_bucket_path,
-        local_directory,
-        "--no-sign-request",
-    ])
     run_phi3v()
@@ -36,7 +36,7 @@
 ```
 export OTEL_SERVICE_NAME="vllm-server"
 export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true
-python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"
+vllm serve facebook/opt-125m --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"
 ```
 
 1. In a new shell, send requests with trace context from a dummy client
@@ -62,7 +62,7 @@ By default, `grpc` is used. To set `http/protobuf` as the protocol, configure th
 ```
 export OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=http/protobuf
 export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=http://$JAEGER_IP:4318/v1/traces
-python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"
+vllm serve facebook/opt-125m --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"
 ```
 
 ## Instrumentation of FastAPI
@@ -74,7 +74,7 @@ OpenTelemetry allows automatic instrumentation of FastAPI.
 
 1. Run vLLM with `opentelemetry-instrument`
 ```
-opentelemetry-instrument python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m"
+opentelemetry-instrument vllm serve facebook/opt-125m
 ```
 
 1. Send a request to vLLM and find its trace in Jaeger. It should contain spans from FastAPI.
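As a rough illustration of the "dummy client" step in the tracing documentation above, the sketch below sends one request carrying an explicit W3C trace context header. The endpoint and payload follow the standard OpenAI-compatible completions API; the traceparent value and port are made-up examples, not values from this commit.

```python
# Illustrative only: send a request to the locally served model with an
# explicit trace context so the resulting span can be found in Jaeger.
import requests

headers = {
    # W3C traceparent format: version-trace_id-parent_span_id-flags (example value)
    "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
}
payload = {
    "model": "facebook/opt-125m",
    "prompt": "San Francisco is a",
    "max_tokens": 16,
}
response = requests.post("http://localhost:8000/v1/completions",
                         headers=headers,
                         json=payload)
print(response.json())
```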
@@ -10,8 +10,7 @@ Install:
 
 Prometheus metric logging is enabled by default in the OpenAI-compatible server. Launch via the entrypoint:
 ```bash
-python3 -m vllm.entrypoints.openai.api_server \
-    --model mistralai/Mistral-7B-v0.1 \
+vllm serve mistralai/Mistral-7B-v0.1 \
     --max-model-len 2048 \
     --disable-log-requests
 ```
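A minimal sketch of consuming the Prometheus endpoint exposed by the server launched above: it fetches `/metrics` and prints the vLLM-specific series. The `vllm:` metric-name prefix used in the filter and the port are assumptions for illustration, not something this diff establishes.

```python
# Illustrative only: poll the Prometheus text endpoint and show vLLM metrics.
import requests

metrics = requests.get("http://localhost:8000/metrics").text
for line in metrics.splitlines():
    if line.startswith("vllm:"):
        print(line)
```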
49  examples/run_cluster.sh  Normal file
@@ -0,0 +1,49 @@
#!/bin/bash

# Check for minimum number of required arguments
if [ $# -lt 4 ]; then
    echo "Usage: $0 docker_image head_node_address --head|--worker path_to_hf_home [additional_args...]"
    exit 1
fi

# Assign the first four arguments and shift them away
DOCKER_IMAGE="$1"
HEAD_NODE_ADDRESS="$2"
NODE_TYPE="$3"  # Should be --head or --worker
PATH_TO_HF_HOME="$4"
shift 4

# Additional arguments are passed directly to the Docker command
ADDITIONAL_ARGS="$@"

# Validate node type
if [ "${NODE_TYPE}" != "--head" ] && [ "${NODE_TYPE}" != "--worker" ]; then
    echo "Error: Node type must be --head or --worker"
    exit 1
fi

# Define a function to cleanup on EXIT signal
cleanup() {
    docker stop node
    docker rm node
}
trap cleanup EXIT

# Command setup for head or worker node
RAY_START_CMD="ray start --block"
if [ "${NODE_TYPE}" == "--head" ]; then
    RAY_START_CMD+=" --head --port=6379"
else
    RAY_START_CMD+=" --address=${HEAD_NODE_ADDRESS}:6379"
fi

# Run the docker command with the user specified parameters and additional arguments
docker run \
    --entrypoint /bin/bash \
    --network host \
    --name node \
    --shm-size 10.24g \
    --gpus all \
    -v "${PATH_TO_HF_HOME}:/root/.cache/huggingface" \
    ${ADDITIONAL_ARGS} \
    "${DOCKER_IMAGE}" -c "${RAY_START_CMD}"
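For reference, a typical head-node invocation of the new script above might look like `bash run_cluster.sh vllm/vllm-openai 192.168.0.10 --head /mnt/hf_cache`, with each worker node running the same command but passing `--worker` and the head node's address; the image tag, IP address, and cache path here are illustrative placeholders rather than values taken from this commit.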
@@ -2,5 +2,9 @@
 -r requirements-common.txt
 
 # Dependencies for AMD GPUs
+awscli
+boto3
+botocore
 ray >= 2.10.0
+peft
 pytest-asyncio
@@ -1,15 +0,0 @@
---- amd_hip_bf16.h	2024-02-06 18:28:58.268699142 +0000
-+++ amd_hip_bf16.h.new	2024-02-06 18:28:31.988647133 +0000
-@@ -90,10 +90,10 @@
- #include "math_fwd.h" // ocml device functions
- 
- #if defined(__HIPCC_RTC__)
--#define __HOST_DEVICE__ __device__
-+#define __HOST_DEVICE__ __device__ static
- #else
- #include <climits>
--#define __HOST_DEVICE__ __host__ __device__
-+#define __HOST_DEVICE__ __host__ __device__ static inline
- #endif
- 
- // Since we are using unsigned short to represent data in bfloat16, it can be of different sizes on
@@ -1,11 +1,10 @@
 import os
 import pathlib
-from dataclasses import dataclass
 
 import pytest
 
+from vllm.entrypoints.chat_utils import load_chat_template
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
-from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
 chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
@@ -50,24 +49,9 @@ TEST_MESSAGES = [
 ]
 
 
-@dataclass
-class MockTokenizer:
-    chat_template = None
-
-
-@dataclass
-class MockServingChat:
-    tokenizer: MockTokenizer
-
-
 def test_load_chat_template():
     # Testing chatml template
-    tokenizer = MockTokenizer()
-    mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=chatml_jinja_path)
-
-    template_content = tokenizer.chat_template
+    template_content = load_chat_template(chat_template=chatml_jinja_path)
 
     # Test assertions
     assert template_content is not None
@@ -79,24 +63,16 @@ def test_load_chat_template():
 def test_no_load_chat_template_filelike():
     # Testing chatml template
     template = "../../examples/does_not_exist"
-    tokenizer = MockTokenizer()
-
-    mock_serving_chat = MockServingChat(tokenizer)
 
     with pytest.raises(ValueError, match="looks like a file path"):
-        OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                              chat_template=template)
+        load_chat_template(chat_template=template)
 
 
 def test_no_load_chat_template_literallike():
     # Testing chatml template
     template = "{{ messages }}"
-    tokenizer = MockTokenizer()
 
-    mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
-    template_content = tokenizer.chat_template
+    template_content = load_chat_template(chat_template=template)
 
     assert template_content == template
 
@@ -108,9 +84,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
                         expected_output):
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
-    mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    template_content = load_chat_template(chat_template=template)
 
     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
@@ -122,7 +96,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
     result = tokenizer.apply_chat_template(
         conversation=mock_request.messages,
         tokenize=False,
-        add_generation_prompt=mock_request.add_generation_prompt)
+        add_generation_prompt=mock_request.add_generation_prompt,
+        chat_template=mock_request.chat_template or template_content)
 
     # Test assertion
     assert result == expected_output, (
@@ -9,17 +9,17 @@ MODEL_NAME = "facebook/opt-125m"
 
 @pytest.fixture(scope="module")
 def server():
-    with RemoteOpenAIServer([
-        "--model",
-        MODEL_NAME,
+    args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "float16",
         "--max-model-len",
         "2048",
         "--enforce-eager",
         "--engine-use-ray"
-    ]) as remote_server:
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server
13  tests/basic_correctness/test_cpu_offload.py  Normal file
@@ -0,0 +1,13 @@
from vllm.utils import is_hip

from ..utils import compare_two_settings


def test_cpu_offload():
    compare_two_settings("meta-llama/Llama-2-7b-hf", [],
                         ["--cpu-offload-gb", "4"])
    if not is_hip():
        # compressed-tensors quantization is currently not supported in ROCm.
        compare_two_settings(
            "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", [],
            ["--cpu-offload-gb", "1"])
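A brief, hedged sketch of the CPU-offload setting exercised by the new test above, used directly through the Python API. The keyword name is assumed to mirror the `--cpu-offload-gb` CLI flag, and the model is just an example.

```python
# Illustrative only: offload part of the model weights to CPU memory.
from vllm import LLM

llm = LLM(model="meta-llama/Llama-2-7b-hf", cpu_offload_gb=4)
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```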
@@ -3,11 +3,7 @@ import gc
 import os
 import sys
 from collections import UserList
-from dataclasses import dataclass
-from functools import cached_property
-from pathlib import Path
-from typing import (Any, Dict, List, Literal, Optional, Tuple, TypedDict,
-                    TypeVar)
+from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar
 
 import pytest
 import torch
@@ -18,14 +14,16 @@ from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
                           AutoTokenizer, BatchEncoding)
 
 from vllm import LLM, SamplingParams
+from vllm.assets.image import ImageAsset
 from vllm.config import TokenizerPoolConfig
+from vllm.connections import global_http_connection
 from vllm.distributed import (destroy_distributed_environment,
                               destroy_model_parallel)
 from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
-from vllm.multimodal.utils import fetch_image
 from vllm.sequence import SampleLogprobs
-from vllm.utils import cuda_device_count_stateless, is_cpu
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
+                        is_cpu)
 
 logger = init_logger(__name__)
 
@@ -33,9 +31,6 @@ _TEST_DIR = os.path.dirname(__file__)
 _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
 _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
 
-_IMAGE_DIR = Path(_TEST_DIR) / "images"
-"""You can use `.buildkite/download-images.sh` to download the assets."""
-
 
 def _read_prompts(filename: str) -> List[str]:
     with open(filename, "r") as f:
@@ -43,24 +38,9 @@ def _read_prompts(filename: str) -> List[str]:
     return prompts
 
 
-@dataclass(frozen=True)
-class ImageAsset:
-    name: Literal["stop_sign", "cherry_blossom", "boardwalk"]
-
-    @cached_property
-    def pil_image(self) -> Image.Image:
-        if self.name == "boardwalk":
-            return fetch_image(
-                "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
-            )
-
-        return Image.open(_IMAGE_DIR / f"{self.name}.jpg")
-
-
 class _ImageAssetPrompts(TypedDict):
     stop_sign: str
     cherry_blossom: str
-    boardwalk: str
 
 
 if sys.version_info < (3, 9):
@@ -79,7 +59,6 @@ class _ImageAssets(_ImageAssetsBase):
         super().__init__([
             ImageAsset("stop_sign"),
             ImageAsset("cherry_blossom"),
-            ImageAsset("boardwalk")
         ])
 
     def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
@@ -89,16 +68,20 @@ class _ImageAssets(_ImageAssetsBase):
         The order of the returned prompts matches the order of the
         assets when iterating through this object.
         """
-        return [
-            prompts["stop_sign"], prompts["cherry_blossom"],
-            prompts["boardwalk"]
-        ]
+        return [prompts["stop_sign"], prompts["cherry_blossom"]]
 
 
 IMAGE_ASSETS = _ImageAssets()
 """Singleton instance of :class:`_ImageAssets`."""
 
 
+@pytest.fixture(autouse=True)
+def init_test_http_connection():
+    # pytest_asyncio may use a different event loop per test
+    # so we need to make sure the async client is created anew
+    global_http_connection.reuse_client = False
+
+
 def cleanup():
     destroy_model_parallel()
     destroy_distributed_environment()
@@ -150,12 +133,6 @@ def image_assets() -> _ImageAssets:
     return IMAGE_ASSETS
 
 
-_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.half,
-    "bfloat16": torch.bfloat16,
-    "float": torch.float,
-}
-
 _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
 
 
@@ -177,8 +154,7 @@ class HfRunner:
         is_vision_model: bool = False,
         is_sparseml_model: bool = False,
     ) -> None:
-        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
-        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
 
         self.model_name = model_name
 
@@ -590,6 +566,10 @@ def get_tokenizer_pool_config(tokenizer_group_type):
         return TokenizerPoolConfig(pool_size=1,
                                    pool_type="ray",
                                    extra_config={})
+    if isinstance(tokenizer_group_type, type):
+        return TokenizerPoolConfig(pool_size=1,
+                                   pool_type=tokenizer_group_type,
+                                   extra_config={})
     raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
@@ -249,10 +249,13 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append,
 
     # Expect consumed blocks to be new blocks required to support the new slots.
     expected_consumed_blocks = len(
-        chunk_list(
-            list(
-                range(prompt_len + num_slots_to_append + num_lookahead_slots)),
-            block_size)) - len(chunk_list(list(range(prompt_len)), block_size))
+        list(
+            chunk_list(
+                list(
+                    range(prompt_len + num_slots_to_append +
+                          num_lookahead_slots)),
+                block_size))) - len(
+                    list(chunk_list(list(range(prompt_len)), block_size)))
     assert num_consumed_blocks == expected_consumed_blocks
 
 
@@ -58,10 +58,10 @@ def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int,
 
     unique_token_ids = list(
        range((num_cpu_blocks + num_gpu_blocks) * block_size))
-    gpu_token_ids = chunk_list(unique_token_ids[:num_gpu_blocks * block_size],
-                               block_size)
-    cpu_token_ids = chunk_list(unique_token_ids[num_gpu_blocks * block_size:],
-                               block_size)
+    gpu_token_ids = list(
+        chunk_list(unique_token_ids[:num_gpu_blocks * block_size], block_size))
+    cpu_token_ids = list(
+        chunk_list(unique_token_ids[num_gpu_blocks * block_size:], block_size))
 
     assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
     assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
@@ -462,7 +462,7 @@ def test_prefill_schedule_max_lora():
             lora_request=LoRARequest(
                 lora_name=str(i),
                 lora_int_id=i + 1,
-                lora_local_path="abc"))
+                lora_path="abc"))
         waiting.append(seq_group)
     # Add two more requests to verify lora is prioritized.
     # 0: Lora, 1: Lora, 2: regular, 3: regular
@@ -760,7 +760,7 @@ def test_schedule_swapped_max_loras():
             lora_request=LoRARequest(
                 lora_name=str(i),
                 lora_int_id=i + 1,
-                lora_local_path="abc"))
+                lora_path="abc"))
         scheduler._allocate_and_set_running(seq_group)
         append_new_token_seq_group(60, seq_group, 1)
         scheduler._swap_out(seq_group, blocks_to_swap_out)
@@ -1,140 +1,63 @@
 import os
 
-import openai  # use the official client for correctness check
 import pytest
 
-from ..utils import RemoteOpenAIServer
+from ..utils import compare_two_settings
 
-# downloading lora to test lora requests
-
-# any model with a chat template should work here
-MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
-EAGER_MODE = bool(int(os.getenv("EAGER_MODE", 0)))
-CHUNKED_PREFILL = bool(int(os.getenv("CHUNKED_PREFILL", 0)))
-TP_SIZE = int(os.getenv("TP_SIZE", 1))
-PP_SIZE = int(os.getenv("PP_SIZE", 1))
-
-pytestmark = pytest.mark.asyncio
-
-
-@pytest.fixture(scope="module")
-def server():
-    args = [
-        "--model",
-        MODEL_NAME,
+VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
+
+
+@pytest.mark.parametrize(
+    "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, DIST_BACKEND",
+    [
+        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+    ])
+def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
+                    DIST_BACKEND):
+    if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
+        pytest.skip("Skipping multi-node pipeline parallel test for "
+                    "multiprocessing distributed backend")
+
+    pp_args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
-        "bfloat16",
+        "float16",
         "--pipeline-parallel-size",
         str(PP_SIZE),
         "--tensor-parallel-size",
         str(TP_SIZE),
         "--distributed-executor-backend",
-        "ray",
+        DIST_BACKEND,
+    ]
+
+    # compare without pipeline parallelism
+    # NOTE: use mp backend for TP
+    # PP tests might involve multiple nodes, and ray might
+    # schedule all workers in a node other than the head node,
+    # which can cause the test to fail.
+    tp_args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--tensor-parallel-size",
+        str(max(TP_SIZE, 2)),  # We only use 2 GPUs in the CI.
+        "--distributed-executor-backend",
+        "mp",
     ]
     if CHUNKED_PREFILL:
-        args += [
-            "--enable-chunked-prefill",
-        ]
+        pp_args.append("--enable-chunked-prefill")
+        tp_args.append("--enable-chunked-prefill")
     if EAGER_MODE:
-        args += [
-            "--enforce-eager",
-        ]
-    with RemoteOpenAIServer(args) as remote_server:
-        yield remote_server
-
-
-@pytest.fixture(scope="module")
-def client(server):
-    return server.get_async_client()
-
-
-async def test_check_models(server, client: openai.AsyncOpenAI):
-    models = await client.models.list()
-    models = models.data
-    served_model = models[0]
-    assert served_model.id == MODEL_NAME
-    assert all(model.root == MODEL_NAME for model in models)
-
-
-@pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_single_completion(server, client: openai.AsyncOpenAI,
-                                 model_name: str):
-    completion = await client.completions.create(model=model_name,
-                                                 prompt="Hello, my name is",
-                                                 max_tokens=5,
-                                                 temperature=0.0)
-
-    assert completion.id is not None
-    assert completion.choices is not None and len(completion.choices) == 1
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
-    assert completion.choices[0].finish_reason == "length"
-    assert completion.usage == openai.types.CompletionUsage(
-        completion_tokens=5, prompt_tokens=6, total_tokens=11)
-
-    # test using token IDs
-    completion = await client.completions.create(
-        model=MODEL_NAME,
-        prompt=[0, 0, 0, 0, 0],
-        max_tokens=5,
-        temperature=0.0,
-    )
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
-
-
-@pytest.mark.parametrize(
-    # just test 1 lora hereafter
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_batch_completions(server, client: openai.AsyncOpenAI,
-                                 model_name: str):
-    # test simple list
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=["Hello, my name is", "Hello, my name is"],
-        max_tokens=5,
-        temperature=0.0,
-    )
-    assert len(batch.choices) == 2
-    assert batch.choices[0].text == batch.choices[1].text
-
-    # test n = 2
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=["Hello, my name is", "Hello, my name is"],
-        n=2,
-        max_tokens=5,
-        temperature=0.0,
-        extra_body=dict(
-            # NOTE: this has to be true for n > 1 in vLLM, but not necessary
-            # for official client.
-            use_beam_search=True),
-    )
-    assert len(batch.choices) == 4
-    assert batch.choices[0].text != batch.choices[
-        1].text, "beam search should be different"
-    assert batch.choices[0].text == batch.choices[
-        2].text, "two copies of the same prompt should be the same"
-    assert batch.choices[1].text == batch.choices[
-        3].text, "two copies of the same prompt should be the same"
-
-    # test streaming
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=["Hello, my name is", "Hello, my name is"],
-        max_tokens=5,
-        temperature=0.0,
-        stream=True,
-    )
-    texts = [""] * 2
-    async for chunk in batch:
-        assert len(chunk.choices) == 1
-        choice = chunk.choices[0]
-        texts[choice.index] += choice.text
-    assert texts[0] == texts[1]
+        pp_args.append("--enforce-eager")
+        tp_args.append("--enforce-eager")
+
+    compare_two_settings(MODEL_NAME, pp_args, tp_args)
@@ -35,8 +35,8 @@ def sequence_with_eos(text: str, eos_token: str,
 @pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [
     ("This text ends with EOS token", "</s>", 2),
 ])
-@pytest.mark.parametrize("ignore_eos", [True, False, None])
-@pytest.mark.parametrize("include_stop_str_in_output", [True, False, None])
+@pytest.mark.parametrize("ignore_eos", [True, False])
+@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
 @pytest.mark.skip_global_cleanup
 def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int,
                            ignore_eos: bool, include_stop_str_in_output: bool):
91  tests/engine/test_custom_executor.py  Normal file
@@ -0,0 +1,91 @@
import asyncio
import os

import pytest

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.gpu_executor import GPUExecutor, GPUExecutorAsync
from vllm.sampling_params import SamplingParams


class Mock:
    ...


class CustomGPUExecutor(GPUExecutor):

    def execute_model(self, *args, **kwargs):
        # Drop marker to show that this was run
        with open(".marker", "w"):
            ...
        return super().execute_model(*args, **kwargs)


class CustomGPUExecutorAsync(GPUExecutorAsync):

    async def execute_model_async(self, *args, **kwargs):
        with open(".marker", "w"):
            ...
        return await super().execute_model_async(*args, **kwargs)


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
def test_custom_executor_type_checking(model):
    with pytest.raises(ValueError):
        engine_args = EngineArgs(model=model,
                                 distributed_executor_backend=Mock)
        LLMEngine.from_engine_args(engine_args)
    with pytest.raises(ValueError):
        engine_args = AsyncEngineArgs(model=model,
                                      distributed_executor_backend=Mock)
        AsyncLLMEngine.from_engine_args(engine_args)
    with pytest.raises(TypeError):
        engine_args = AsyncEngineArgs(
            model=model, distributed_executor_backend=CustomGPUExecutor)
        AsyncLLMEngine.from_engine_args(engine_args)


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
def test_custom_executor(model, tmpdir):
    cwd = os.path.abspath(".")
    os.chdir(tmpdir)
    try:
        assert not os.path.exists(".marker")

        engine_args = EngineArgs(
            model=model, distributed_executor_backend=CustomGPUExecutor)
        engine = LLMEngine.from_engine_args(engine_args)
        sampling_params = SamplingParams(max_tokens=1)

        engine.add_request("0", "foo", sampling_params)
        engine.step()

        assert os.path.exists(".marker")
    finally:
        os.chdir(cwd)


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
def test_custom_executor_async(model, tmpdir):
    cwd = os.path.abspath(".")
    os.chdir(tmpdir)
    try:
        assert not os.path.exists(".marker")

        engine_args = AsyncEngineArgs(
            model=model, distributed_executor_backend=CustomGPUExecutorAsync)
        engine = AsyncLLMEngine.from_engine_args(engine_args)
        sampling_params = SamplingParams(max_tokens=1)

        async def t():
            stream = await engine.add_request("0", "foo", sampling_params)
            async for x in stream:
                ...

        asyncio.run(t())

        assert os.path.exists(".marker")
    finally:
        os.chdir(cwd)
61  tests/entrypoints/openai/test_basic.py  Normal file
@@ -0,0 +1,61 @@
from http import HTTPStatus

import openai
import pytest
import requests

from vllm.version import __version__ as VLLM_VERSION

from ...utils import RemoteOpenAIServer

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"


@pytest.fixture(scope="module")
def server():
    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
        "--max-num-seqs",
        "128",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest.fixture(scope="module")
def client(server):
    return server.get_async_client()


@pytest.mark.asyncio
async def test_show_version(client: openai.AsyncOpenAI):
    base_url = str(client.base_url)[:-3].strip("/")

    response = requests.get(base_url + "/version")
    response.raise_for_status()

    assert response.json() == {"version": VLLM_VERSION}


@pytest.mark.asyncio
async def test_check_health(client: openai.AsyncOpenAI):
    base_url = str(client.base_url)[:-3].strip("/")

    response = requests.get(base_url + "/health")

    assert response.status_code == HTTPStatus.OK


@pytest.mark.asyncio
async def test_log_metrics(client: openai.AsyncOpenAI):
    base_url = str(client.base_url)[:-3].strip("/")

    response = requests.get(base_url + "/metrics")

    assert response.status_code == HTTPStatus.OK
@@ -7,11 +7,11 @@ import jsonschema
 import openai  # use the official client for correctness check
 import pytest
 import torch
-# downloading lora to test lora requests
-from huggingface_hub import snapshot_download
 from openai import BadRequestError
 
 from ...utils import RemoteOpenAIServer
+from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
+from .test_completion import zephyr_lora_files  # noqa: F401
 
 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@@ -21,33 +21,28 @@ LORA_NAME = "typeof/zephyr-7b-beta-lora"
 
 
 @pytest.fixture(scope="module")
-def zephyr_lora_files():
-    return snapshot_download(repo_id=LORA_NAME)
-
-
-@pytest.fixture(scope="module")
-def server(zephyr_lora_files):
-    with RemoteOpenAIServer([
-        "--model",
-        MODEL_NAME,
+def server(zephyr_lora_files, zephyr_lora_added_tokens_files):  # noqa: F811
+    args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
         "--max-model-len",
         "8192",
         "--enforce-eager",
         # lora config below
         "--enable-lora",
         "--lora-modules",
         f"zephyr-lora={zephyr_lora_files}",
-        f"zephyr-lora2={zephyr_lora_files}",
+        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
         "--max-lora-rank",
         "64",
         "--max-cpu-loras",
         "2",
         "--max-num-seqs",
         "128",
-    ]) as remote_server:
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server
 
@@ -1,15 +1,17 @@
 # imports for guided decoding tests
 import json
 import re
+import shutil
+from tempfile import TemporaryDirectory
 from typing import List
 
 import jsonschema
 import openai  # use the official client for correctness check
 import pytest
-import requests
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
 from openai import BadRequestError
+from transformers import AutoTokenizer
 
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
@@ -17,9 +19,13 @@ from ...utils import RemoteOpenAIServer
 
 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
-# technically this needs Mistral-7B-v0.1 as base, but we're not testing
-# generation quality here
+# technically these adapters use a different base model,
+# but we're not testing generation quality here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"
+PA_NAME = "swapnilbp/llama_tweet_ptune"
+# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also
+# need to change to match the prompt adapter
+PA_NUM_VIRTUAL_TOKENS = 8
 
 
 @pytest.fixture(scope="module")
@@ -28,28 +34,58 @@ def zephyr_lora_files():
 
 
 @pytest.fixture(scope="module")
-def server(zephyr_lora_files):
-    with RemoteOpenAIServer([
-        "--model",
-        MODEL_NAME,
-        # use half precision for speed and memory savings in CI environment
-        "--dtype",
-        "bfloat16",
-        "--max-model-len",
-        "8192",
-        "--enforce-eager",
-        # lora config below
-        "--enable-lora",
-        "--lora-modules",
-        f"zephyr-lora={zephyr_lora_files}",
-        f"zephyr-lora2={zephyr_lora_files}",
-        "--max-lora-rank",
-        "64",
-        "--max-cpu-loras",
-        "2",
-        "--max-num-seqs",
-        "128",
-    ]) as remote_server:
+def zephyr_lora_added_tokens_files(zephyr_lora_files):
+    tmp_dir = TemporaryDirectory()
+    tmp_model_dir = f"{tmp_dir.name}/zephyr"
+    shutil.copytree(zephyr_lora_files, tmp_model_dir)
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    # Copy tokenizer to adapter and add some unique tokens
+    # 32000, 32001, 32002
+    added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"],
+                                 special_tokens=True)
+    assert added == 3
+    tokenizer.save_pretrained(tmp_model_dir)
+    yield tmp_model_dir
+    tmp_dir.cleanup()
+
+
+@pytest.fixture(scope="module")
+def zephyr_pa_files():
+    return snapshot_download(repo_id=PA_NAME)
+
+
+@pytest.fixture(scope="module")
+def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
+    args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--max-num-seqs",
+        "128",
+        "--enforce-eager",
+        # lora config
+        "--enable-lora",
+        "--lora-modules",
+        f"zephyr-lora={zephyr_lora_files}",
+        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
+        "--max-lora-rank",
+        "64",
+        "--max-cpu-loras",
+        "2",
+        # pa config
+        "--enable-prompt-adapter",
+        "--prompt-adapters",
+        f"zephyr-pa={zephyr_pa_files}",
+        f"zephyr-pa2={zephyr_pa_files}",
+        "--max-prompt-adapters",
+        "2",
+        "--max-prompt-adapter-token",
+        "128",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server
 
@@ -60,11 +96,14 @@ def client(server):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    # first test base model, then test loras
-    "model_name",
-    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+    # first test base model, then test loras, then test prompt adapters
+    "model_name,num_virtual_tokens",
+    [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0),
+     ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS),
+     ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)],
 )
-async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
+async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
+                                 num_virtual_tokens: int):
     completion = await client.completions.create(model=model_name,
                                                  prompt="Hello, my name is",
                                                  max_tokens=5,
@@ -77,28 +116,58 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
     assert len(choice.text) >= 5
     assert choice.finish_reason == "length"
     assert completion.usage == openai.types.CompletionUsage(
-        completion_tokens=5, prompt_tokens=6, total_tokens=11)
+        completion_tokens=5,
+        prompt_tokens=6 + num_virtual_tokens,
+        total_tokens=11 + num_virtual_tokens)
 
     # test using token IDs
     completion = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=[0, 0, 0, 0, 0],
         max_tokens=5,
         temperature=0.0,
     )
-    assert len(completion.choices[0].text) >= 5
+    assert len(completion.choices[0].text) >= 1
+
+
+@pytest.mark.asyncio
+async def test_added_lora_tokens(client: openai.AsyncOpenAI):
+    # test using token IDs
+    completion = await client.completions.create(
+        model="zephyr-lora2",
+        prompt=[0, 0, 32000, 32001, 32002],
+        echo=True,
+        max_tokens=5,
+        temperature=0.0,
+    )
+    # Added tokens should appear in tokenized prompt
+    assert completion.choices[0].text.startswith("<unk><unk>vllm1vllm2vllm3")
+
+
+@pytest.mark.asyncio
+async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
+    # test using token IDs
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=[0, 0, 32000, 32001, 32002],
+        echo=True,
+        max_tokens=5,
+        temperature=0.0,
+    )
+    # Added tokens should not appear in tokenized prompt
+    assert "vllm" not in completion.choices[0].text
 
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    # first test base model, then test loras
+    # first test base model, then test loras, then test prompt adapters
     "model_name",
-    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"],
 )
 async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
     # test using token IDs
     completion = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=[0, 0, 0, 0, 0],
         max_tokens=5,
         temperature=0.0,
@@ -110,14 +179,14 @@ async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    # just test 1 lora hereafter
+    # just test 1 lora and 1 pa hereafter
     "model_name",
-    [MODEL_NAME, "zephyr-lora"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
 )
 async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
     # test using token IDs
     completion = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=[0, 0, 0, 0, 0],
         max_tokens=5,
         temperature=0.0,
@@ -133,12 +202,12 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
-    [MODEL_NAME, "zephyr-lora"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
 )
 async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
     # test using token IDs
     completion = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=[0, 0, 0, 0, 0],
         max_tokens=5,
         temperature=0.0,
@@ -154,7 +223,7 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
-    [MODEL_NAME, "zephyr-lora"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
 )
 async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
                                             model_name: str):
@@ -162,7 +231,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
     with pytest.raises(
         (openai.BadRequestError, openai.APIError)):  # test using token IDs
         await client.completions.create(
-            model=MODEL_NAME,
+            model=model_name,
             prompt=[0, 0, 0, 0, 0],
             max_tokens=5,
             temperature=0.0,
@@ -174,7 +243,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
     with pytest.raises(
         (openai.BadRequestError, openai.APIError)):  # test using token IDs
         stream = await client.completions.create(
-            model=MODEL_NAME,
+            model=model_name,
             prompt=[0, 0, 0, 0, 0],
             max_tokens=5,
             temperature=0.0,
@@ -199,7 +268,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
-    [MODEL_NAME, "zephyr-lora"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
 )
 async def test_completion_streaming(client: openai.AsyncOpenAI,
                                     model_name: str):
@@ -233,7 +302,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
-    ["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
 )
 async def test_completion_stream_options(client: openai.AsyncOpenAI,
                                          model_name: str):
@@ -369,9 +438,8 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    # just test 1 lora hereafter
     "model_name",
-    [MODEL_NAME, "zephyr-lora"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
 )
 async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
     # test both text and token IDs
@@ -614,51 +682,3 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
         prompt="Give an example string that fits this regex",
         extra_body=dict(guided_regex=sample_regex,
                         guided_json=sample_json_schema))
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_tokenize(client: openai.AsyncOpenAI, model_name: str):
-    base_url = str(client.base_url)[:-3].strip("/")
-    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")
-
-    for add_special in [False, True]:
-        prompt = "This is a test prompt."
-        tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
-
-        response = requests.post(base_url + "/tokenize",
-                                 json={
-                                     "add_special_tokens": add_special,
-                                     "model": model_name,
-                                     "prompt": prompt
-                                 })
-        response.raise_for_status()
-        assert response.json() == {
-            "tokens": tokens,
-            "count": len(tokens),
-            "max_model_len": 8192
-        }
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
-    base_url = str(client.base_url)[:-3]
-    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")
-
-    prompt = "This is a test prompt."
-    tokens = tokenizer.encode(prompt, add_special_tokens=False)
-
-    response = requests.post(base_url + "detokenize",
-                             json={
-                                 "model": model_name,
-                                 "tokens": tokens
-                             })
-    response.raise_for_status()
-    assert response.json() == {"prompt": prompt}
@@ -11,17 +11,17 @@ EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"

 @pytest.fixture(scope="module")
 def embedding_server():
-    with RemoteOpenAIServer([
-        "--model",
-        EMBEDDING_MODEL_NAME,
+    args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
         "--enforce-eager",
         "--max-model-len",
         "8192",
         "--enforce-eager",
-    ]) as remote_server:
+    ]
+
+    with RemoteOpenAIServer(EMBEDDING_MODEL_NAME, args) as remote_server:
         yield remote_server

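The fixture change above reflects a pattern repeated throughout this diff: RemoteOpenAIServer now appears to take the model name as its first argument and the remaining CLI flags as a separate list, instead of one argv list that includes --model. A minimal sketch of starting a server under that assumption (the helper and its get_async_client() accessor come from these tests; the surrounding code is illustrative only):

    # Sketch only: assumes the RemoteOpenAIServer(model, args) form shown in this diff.
    from tests.utils import RemoteOpenAIServer

    args = ["--dtype", "bfloat16", "--max-model-len", "8192", "--enforce-eager"]

    with RemoteOpenAIServer("intfloat/e5-mistral-7b-instruct", args) as server:
        client = server.get_async_client()  # an openai.AsyncOpenAI pointed at the server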
@@ -19,27 +19,27 @@ def zephyr_lora_files():

 @pytest.fixture(scope="module")
 def server(zephyr_lora_files):
-    with RemoteOpenAIServer([
-        "--model",
-        MODEL_NAME,
+    args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
         "--max-model-len",
         "8192",
         "--enforce-eager",
         # lora config below
         "--enable-lora",
         "--lora-modules",
         f"zephyr-lora={zephyr_lora_files}",
         f"zephyr-lora2={zephyr_lora_files}",
         "--max-lora-rank",
         "64",
         "--max-cpu-loras",
         "2",
         "--max-num-seqs",
         "128",
-    ]) as remote_server:
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server

@@ -32,11 +32,13 @@ async def _async_serving_chat_init():
         model_config,
         served_model_names=[MODEL_NAME],
         response_role="assistant",
-        chat_template=CHAT_TEMPLATE)
+        chat_template=CHAT_TEMPLATE,
+        lora_modules=None,
+        prompt_adapters=None,
+        request_logger=None)
     return serving_completion


 def test_async_serving_chat_init():
     serving_completion = asyncio.run(_async_serving_chat_init())
-    assert serving_completion.tokenizer is not None
-    assert serving_completion.tokenizer.chat_template == CHAT_TEMPLATE
+    assert serving_completion.chat_template == CHAT_TEMPLATE
tests/entrypoints/openai/test_tokenization.py (new file, 152 lines)
@@ -0,0 +1,152 @@
+import openai  # use the official client for correctness check
+import pytest
+import requests
+
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+from ...utils import RemoteOpenAIServer
+from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
+from .test_completion import zephyr_lora_files  # noqa: F401
+
+# any model with a chat template should work here
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+
+@pytest.fixture(scope="module")
+def server(zephyr_lora_added_tokens_files: str):  # noqa: F811
+    args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--enforce-eager",
+        "--max-num-seqs",
+        "128",
+        # lora config
+        "--enable-lora",
+        "--lora-modules",
+        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
+        "--max-lora-rank",
+        "64",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def tokenizer_name(model_name: str,
+                   zephyr_lora_added_tokens_files: str):  # noqa: F811
+    return zephyr_lora_added_tokens_files if (
+        model_name == "zephyr-lora2") else model_name
+
+
+@pytest.fixture(scope="module")
+def client(server):
+    return server.get_async_client()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name,tokenizer_name",
+    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
+    indirect=["tokenizer_name"],
+)
+async def test_tokenize_completions(client: openai.AsyncOpenAI,
+                                    model_name: str, tokenizer_name: str):
+    base_url = str(client.base_url)[:-3].strip("/")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
+                              tokenizer_mode="fast")
+
+    for add_special in [False, True]:
+        prompt = "vllm1 This is a test prompt."
+        tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
+
+        response = requests.post(base_url + "/tokenize",
+                                 json={
+                                     "add_special_tokens": add_special,
+                                     "model": model_name,
+                                     "prompt": prompt
+                                 })
+        response.raise_for_status()
+
+        assert response.json() == {
+            "tokens": tokens,
+            "count": len(tokens),
+            "max_model_len": 8192
+        }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name,tokenizer_name",
+    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
+    indirect=["tokenizer_name"],
+)
+async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
+                             tokenizer_name: str):
+    base_url = str(client.base_url)[:-3].strip("/")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
+                              tokenizer_mode="fast")
+
+    for add_generation in [False, True]:
+        for add_special in [False, True]:
+            conversation = [{
+                "role": "user",
+                "content": "Hi there!"
+            }, {
+                "role": "assistant",
+                "content": "Nice to meet you!"
+            }, {
+                "role": "user",
+                "content": "Can I ask a question? vllm1"
+            }]
+
+            prompt = tokenizer.apply_chat_template(
+                add_generation_prompt=add_generation,
+                conversation=conversation,
+                tokenize=False)
+            tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
+
+            response = requests.post(base_url + "/tokenize",
+                                     json={
+                                         "add_generation_prompt":
+                                         add_generation,
+                                         "add_special_tokens": add_special,
+                                         "messages": conversation,
+                                         "model": model_name
+                                     })
+            response.raise_for_status()
+
+            assert response.json() == {
+                "tokens": tokens,
+                "count": len(tokens),
+                "max_model_len": 8192
+            }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name,tokenizer_name",
+    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
+    indirect=["tokenizer_name"],
+)
+async def test_detokenize(client: openai.AsyncOpenAI, model_name: str,
+                          tokenizer_name: str):
+    base_url = str(client.base_url)[:-3].strip("/")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
+                              tokenizer_mode="fast")
+
+    prompt = "This is a test prompt. vllm1"
+    tokens = tokenizer.encode(prompt, add_special_tokens=False)
+
+    print(f"CALLING {base_url} FOR {model_name}")
+    response = requests.post(base_url + "/detokenize",
+                             json={
+                                 "model": model_name,
+                                 "tokens": tokens
+                             })
+    response.raise_for_status()
+
+    assert response.json() == {"prompt": prompt}
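The new test file above drives the server's /tokenize and /detokenize HTTP endpoints directly. A hedged sketch of the request and response shapes those tests imply (field names are copied from the test bodies; the base URL is illustrative, since the tests derive it from the client):

    import requests

    base_url = "http://localhost:8000"  # illustrative

    resp = requests.post(base_url + "/tokenize",
                         json={
                             "model": "HuggingFaceH4/zephyr-7b-beta",
                             "prompt": "This is a test prompt.",
                             "add_special_tokens": True,
                         })
    resp.raise_for_status()
    result = resp.json()  # {"tokens": [...], "count": ..., "max_model_len": ...}

    resp = requests.post(base_url + "/detokenize",
                         json={
                             "model": "HuggingFaceH4/zephyr-7b-beta",
                             "tokens": result["tokens"],
                         })
    resp.raise_for_status()
    print(resp.json())  # {"prompt": "..."}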
@@ -2,9 +2,8 @@ from typing import Dict, List

 import openai
 import pytest
-import pytest_asyncio

-from vllm.multimodal.utils import ImageFetchAiohttp, encode_image_base64
+from vllm.multimodal.utils import encode_image_base64, fetch_image

 from ...utils import VLLM_PATH, RemoteOpenAIServer

@@ -23,17 +22,17 @@ TEST_IMAGE_URLS = [

 @pytest.fixture(scope="module")
 def server():
-    with RemoteOpenAIServer([
-        "--model",
-        MODEL_NAME,
+    args = [
         "--dtype",
         "bfloat16",
         "--max-model-len",
         "4096",
         "--enforce-eager",
         "--chat-template",
         str(LLAVA_CHAT_TEMPLATE),
-    ]) as remote_server:
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server

@@ -42,11 +41,10 @@ def client(server):
     return server.get_async_client()


-@pytest_asyncio.fixture(scope="session")
-async def base64_encoded_image() -> Dict[str, str]:
+@pytest.fixture(scope="session")
+def base64_encoded_image() -> Dict[str, str]:
     return {
-        image_url:
-        encode_image_base64(await ImageFetchAiohttp.fetch_image(image_url))
+        image_url: encode_image_base64(fetch_image(image_url))
         for image_url in TEST_IMAGE_URLS
     }

tests/kernels/quant_utils.py (new file, 72 lines)
@@ -0,0 +1,72 @@
+from typing import Optional, Tuple, Union
+
+import torch
+
+
+def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
+    return torch.as_tensor(x, dtype=torch.float32, device='cuda')
+
+def ref_dynamic_per_token_quant(x: torch.tensor,
+                                quant_dtype: torch.dtype,
+                                scale_ub: Optional[torch.tensor] = None) \
+        -> Tuple[torch.tensor, torch.tensor]:
+
+    assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
+    if scale_ub is not None:
+        assert quant_dtype == torch.float8_e4m3fn
+
+    qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
+            else torch.finfo(quant_dtype)
+    qtype_max = as_float32_tensor(qtype_traits.max)
+    s_1 = as_float32_tensor(1.0)
+    s_512 = as_float32_tensor(512.0)
+
+    # For fp8, in order to match the cuda kernel output, we have to do exactly
+    # the same operations as in the corresponding fp8 kernel to prevent
+    # rounding errors.
+
+    # Compute scales
+    x_token_max, _ = x.abs().max(dim=-1)
+    x_token_max = as_float32_tensor(x_token_max)
+    if scale_ub is not None:
+        x_token_max = x_token_max.clamp(max=scale_ub)
+    scales = (x_token_max / qtype_max)[:, None]
+
+    # Quant
+    if quant_dtype == torch.int8:
+        iscales = as_float32_tensor(s_1 / scales)
+        torch_out = as_float32_tensor(x) * iscales
+        torch_out = torch_out.round()
+        torch_out = torch_out.clamp(qtype_traits.min,
+                                    qtype_traits.max).to(quant_dtype)
+    else:
+        assert quant_dtype == torch.float8_e4m3fn
+        min_scaling_factor = s_1 / (qtype_max * s_512)
+        scales = scales.clamp(min=min_scaling_factor)
+        torch_out = as_float32_tensor(x) / scales
+        torch_out = torch_out.clamp(qtype_traits.min,
+                                    qtype_traits.max).to(quant_dtype)
+
+    return torch_out, scales
+
+
+# The int8 version is very similar. Incorporate the int8 version, like in
+# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
+# kernel
+def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
+        -> Tuple[torch.tensor, torch.tensor]:
+
+    fp8_traits = torch.finfo(torch.float8_e4m3fn)
+    fp8_max = as_float32_tensor(fp8_traits.max)
+    one = as_float32_tensor(1.0)
+
+    # For fp8, in order to match the cuda kernel output, we have to do exactly
+    # the same operations as in the corresponding fp8 kernel to prevent
+    # rounding errors.
+
+    x_max = as_float32_tensor(x.abs().max())
+    ref_scale = x_max / fp8_max
+    ref_iscale = one / ref_scale
+    ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
+        fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn)
+    return ref_out, ref_scale
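In the per-token reference above, each row gets its own scale, scale_i = max_j |x_ij| / q_max, and values are quantized as round(x_i / scale_i) for int8 or a clamped division for fp8; the per-tensor variant uses one scale for the whole input. A short usage sketch of the two helpers introduced in this file (shapes are illustrative; a CUDA device is assumed, as in the helpers themselves):

    import torch

    from tests.kernels.quant_utils import (ref_dynamic_per_tensor_fp8_quant,
                                           ref_dynamic_per_token_quant)

    x = torch.rand(16, 1024, dtype=torch.float16, device="cuda")

    # One fp8 scale per token (row): scales has shape [16, 1].
    q_tok, scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)

    # A single fp8 scale for the whole tensor.
    q_tensor, scale = ref_dynamic_per_tensor_fp8_quant(x)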
@@ -175,7 +175,7 @@ def test_paged_attention(
     key_cache, value_cache = key_caches[0], value_caches[0]

     # Using default kv_scale
-    kv_scale = 1.0
+    k_scale = v_scale = 1.0

     # Call the paged attention kernel.
     output = torch.empty_like(query)
@@ -193,7 +193,8 @@ def test_paged_attention(
             max_seq_len,
             alibi_slopes,
             kv_cache_dtype,
-            kv_scale,
+            k_scale,
+            v_scale,
         )
     elif version == "v2":
         num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
@@ -224,7 +225,8 @@ def test_paged_attention(
             max_seq_len,
             alibi_slopes,
             kv_cache_dtype,
-            kv_scale,
+            k_scale,
+            v_scale,
         )
     else:
         raise AssertionError(f"Unknown version: {version}")
@@ -212,7 +212,7 @@ def test_paged_attention(
     key_cache, value_cache = key_caches[0], value_caches[0]

     # Using default kv_scale
-    kv_scale = 1.0
+    k_scale = v_scale = 1.0
     tp_rank = 0

     # Call the paged attention kernel.
@@ -231,7 +231,8 @@ def test_paged_attention(
             max_seq_len,
             alibi_slopes,
             kv_cache_dtype,
-            kv_scale,
+            k_scale,
+            v_scale,
             tp_rank=tp_rank,
             blocksparse_local_blocks=blocksparse_local_blocks,
             blocksparse_vert_stride=blocksparse_vert_stride,
@@ -267,7 +268,8 @@ def test_paged_attention(
             max_seq_len,
             alibi_slopes,
             kv_cache_dtype,
-            kv_scale,
+            k_scale,
+            v_scale,
             tp_rank=tp_rank,
             blocksparse_local_blocks=blocksparse_local_blocks,
             blocksparse_vert_stride=blocksparse_vert_stride,
@@ -155,11 +155,11 @@ def test_reshape_and_cache(
     cloned_value_cache = value_cache.clone()

     # Using default kv_scale
-    kv_scale = 1.0
+    k_scale = v_scale = 1.0

     # Call the reshape_and_cache kernel.
     ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
-                          kv_cache_dtype, kv_scale)
+                          kv_cache_dtype, k_scale, v_scale)

     if kv_cache_dtype == "fp8":
         result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
tests/kernels/test_fp8_quant.py (new file, 87 lines)
@@ -0,0 +1,87 @@
+import pytest
+import torch
+
+import vllm._custom_ops as ops
+from tests.kernels.quant_utils import (ref_dynamic_per_tensor_fp8_quant,
+                                       ref_dynamic_per_token_quant)
+
+DTYPES = [torch.half, torch.bfloat16, torch.float]
+HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
+                8193]  # Arbitrary values for testing
+HIDDEN_SIZES += list(range(1024, 1033))  # vectorized conversion edge cases
+NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
+SCALE_UBS = [True, False]
+SEEDS = [0]
+
+
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("scale_ub", SCALE_UBS)
+@pytest.mark.parametrize("seed", SEEDS)
+@torch.inference_mode()
+def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
+                                     dtype: torch.dtype, scale_ub: bool,
+                                     seed: int) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    x = torch.rand(num_tokens, hidden_size, dtype=dtype,
+                   device="cuda") + 1e-6  # avoid nans
+
+    scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \
+            if scale_ub else None
+    ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn,
+                                                      scale_ub)
+    ops_out, ops_scales = ops.scaled_fp8_quant(x,
+                                               scale_ub=scale_ub,
+                                               use_per_token_if_dynamic=True)
+
+    assert torch.allclose(ref_scales, ops_scales)
+    assert torch.allclose(ref_out.to(dtype=torch.float32),
+                          ops_out.to(dtype=torch.float32))
+
+
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@torch.inference_mode()
+def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
+                                      dtype: torch.dtype, seed: int) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
+
+    ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x)
+    ops_out, ops_scale = ops.scaled_fp8_quant(x)
+
+    assert torch.allclose(ref_scale, ops_scale)
+    assert torch.allclose(ref_out.to(dtype=torch.float32),
+                          ops_out.to(dtype=torch.float32))
+
+
+# Regression test for a case with large activations where an int32 index cannot
+# represent the number of elements.
+@torch.inference_mode()
+@pytest.mark.parametrize("seed", SEEDS)
+def test_fp8_quant_large(seed: int) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    num_tokens = 1024000  # Mistral-Nemo's max_position_embeddings
+    hidden_size = 1152  # Smallest hidden_size to reproduce the error
+    dtype = torch.bfloat16
+
+    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
+    ref_out, scale = ref_dynamic_per_tensor_fp8_quant(x)
+    ops_out, _ = ops.scaled_fp8_quant(x, scale)
+
+    # Minimize memory footprint in this test by freeing x and upconverting
+    # the outputs in place. (torch.allclose does not support fp8)
+    del x
+    ref_out = ref_out.to(dtype=dtype)
+    ops_out = ops_out.to(dtype=dtype)
+
+    assert torch.allclose(ref_out, ops_out)
@@ -3,6 +3,8 @@ import torch

 # ruff: noqa: F401
 import vllm._C
+from tests.kernels.quant_utils import ref_dynamic_per_token_quant
+from vllm._custom_ops import scaled_int8_quant

 DTYPES = [torch.half, torch.bfloat16, torch.float]
 HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
@@ -21,23 +23,16 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
                                   dtype: torch.dtype, seed: int) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    int8_traits = torch.iinfo(torch.int8)

     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000

-    x_token_max, _ = x.max(dim=1)
-    x_token_max = x_token_max.to(dtype=torch.float32)
-    scales = (x_token_max / float(127.0))[:, None].to(device="cuda",
-                                                      dtype=torch.float32)
-    torch_out = (x / scales).round().clamp(int8_traits.min,
-                                           int8_traits.max).to(torch.int8)
-
-    ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda")
-    scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda")
-    torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out)
-
-    assert torch.allclose(scales_out, scales)
-    assert torch.allclose(torch_out, ops_out,
+    # reference
+    ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8)
+    # kernel
+    ops_out, ops_scales = scaled_int8_quant(x)
+
+    assert torch.allclose(ops_scales, ref_scales)
+    assert torch.allclose(ops_out, ref_out,
                           atol=1)  # big atol to account for rounding errors

@@ -55,12 +50,11 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
     int8_traits = torch.iinfo(torch.int8)

     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
+    scale = torch.tensor([scale], dtype=torch.float32, device="cuda")

     out1 = (x / scale).round().clamp(int8_traits.min,
                                      int8_traits.max).to(torch.int8)
-    out2 = torch.empty_like(x, dtype=torch.int8)
-    scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
+    out2, _ = scaled_int8_quant(x, scale)

-    torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument)
     assert torch.allclose(out1, out2,
                           atol=1)  # big atol to account for rounding errors
@@ -12,16 +12,18 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
-    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS,
-    marlin_permute_scales)
+    MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS,
+    marlin_make_empty_g_idx, marlin_permute_scales)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     pack_fp8_to_int32)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
-    MarlinWorkspace, get_weight_perm, marlin_quantize, marlin_weights)
+    MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
+    marlin_weights)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
     marlin_24_quantize)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    gptq_pack, quantize_weights, sort_weights)
+    awq_pack, gptq_pack, quantize_weights, quantize_weights_with_zp,
+    sort_weights)

 ACT_ORDER_OPTS = [False, True]
 K_FULL_OPTS = [False, True]
@@ -57,12 +59,12 @@ def rand_data(shape, dtype=torch.float16):
                     reason="Marlin is not supported on this GPU type.")
 @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
 @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
-@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
-@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
+@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
+@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
 @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
 @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
-def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
-                       mnk_factors):
+def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
+                            mnk_factors):
     m_factor, n_factor, k_factor = mnk_factors

     size_m = m_factor
@@ -120,12 +122,60 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
                     reason="Marlin is not supported on this GPU type.")
 @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
 @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
-@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
-@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
+@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
+@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
+@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
+def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
+                           mnk_factors):
+    m_factor, n_factor, k_factor = mnk_factors
+
+    size_m = m_factor
+    size_k = k_chunk * k_factor
+    size_n = n_chunk * n_factor
+
+    print(f"MNK = {size_m} {size_n} {size_k}")
+
+    # Normalize group_size
+    if group_size == -1:
+        group_size = size_k
+    assert group_size <= size_k
+
+    # Create input
+    b_weight = rand_data((size_k, size_n))
+
+    # Quantize
+    w_ref, q_w, s, zp = quantize_weights_with_zp(b_weight, num_bits,
+                                                 group_size)
+
+    # Pack to AWQ format
+    q_w_awq = awq_pack(q_w, num_bits, size_k, size_n)
+
+    # Pack to Marlin format
+    weight_perm = get_weight_perm(num_bits)
+    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
+
+    # Run Marlin repack GPU kernel
+    marlin_q_w_2 = ops.awq_marlin_repack(
+        q_w_awq,
+        size_k,
+        size_n,
+        num_bits,
+    )
+    torch.cuda.synchronize()
+
+    assert torch.allclose(marlin_q_w_1, marlin_q_w_2)
+
+
+@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
+                    reason="Marlin is not supported on this GPU type.")
+@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
+@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
+@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
+@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
 @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
 @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
 @pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
-def test_marlin_gemm(
+def test_gptq_marlin_gemm(
     k_chunk,
     n_chunk,
     num_bits,
@@ -155,6 +205,8 @@ def test_marlin_gemm(
     w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
         b_weight, num_bits, group_size, act_order)

+    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
+
     workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                 GPTQ_MARLIN_MAX_PARALLEL)

@@ -162,6 +214,7 @@ def test_marlin_gemm(
         a_input,
         marlin_q_w,
         marlin_s,
+        marlin_zp,
         g_idx,
         sort_indices,
         workspace.scratch,
@@ -170,6 +223,7 @@ def test_marlin_gemm(
         b_weight.shape[1],
         a_input.shape[1],
         is_k_full,
+        has_zp=False,
     )
     output_ref = torch.matmul(a_input, w_ref)

@@ -188,7 +242,8 @@ def test_marlin_gemm(
 @pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
 @pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
 @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
-def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
+def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size,
+                             mnk_factors):
     m_factor, n_factor, k_factor = mnk_factors

     size_m = m_factor
@@ -301,3 +356,65 @@ def test_fp8_marlin_gemm(
     print("max_diff = {}".format(max_diff))

     assert max_diff < 0.04
+
+
+@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
+                    reason="Marlin is not supported on this GPU type.")
+@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
+@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
+@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
+@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
+@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
+def test_awq_marlin_gemm(
+    k_chunk,
+    n_chunk,
+    num_bits,
+    group_size,
+    mnk_factors,
+):
+    m_factor, n_factor, k_factor = mnk_factors
+
+    size_m = m_factor
+    size_k = k_chunk * k_factor
+    size_n = n_chunk * n_factor
+
+    print(f"MNK = {size_m} {size_n} {size_k}")
+    print(f"groupsize = {group_size}")
+
+    a_input = rand_data((size_m, size_k))
+    b_weight = rand_data((size_k, size_n))
+
+    w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
+        b_weight, num_bits, group_size)
+
+    g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
+    sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
+    is_k_full = True
+    has_zp = True
+
+    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
+                                GPTQ_MARLIN_MAX_PARALLEL)
+
+    output = ops.gptq_marlin_gemm(
+        a_input,
+        marlin_q_w,
+        marlin_s,
+        marlin_zp,
+        g_idx,
+        sort_indices,
+        workspace.scratch,
+        num_bits,
+        a_input.shape[0],
+        b_weight.shape[1],
+        a_input.shape[1],
+        is_k_full,
+        has_zp,
+    )
+    output_ref = torch.matmul(a_input, w_ref)
+
+    torch.cuda.synchronize()
+
+    max_diff = compute_max_diff(output, output_ref)
+    print("max_diff = {}".format(max_diff))
+
+    assert max_diff < 0.04
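The new AWQ-path tests above follow a quantize-with-zero-points, pack, repack, then GEMM flow. A condensed sketch of the first half of that flow, restricted to helpers and ops that appear in these hunks (sizes are illustrative; this mirrors test_awq_marlin_repack rather than defining any new API):

    import torch

    import vllm._custom_ops as ops
    from vllm.model_executor.layers.quantization.utils.quant_utils import (
        awq_pack, quantize_weights_with_zp)

    num_bits, group_size, size_k, size_n = 4, 128, 1024, 512
    b_weight = torch.randn(size_k, size_n, dtype=torch.half, device="cuda")

    # Quantize with zero points, pack to the AWQ layout, then repack for Marlin.
    w_ref, q_w, s, zp = quantize_weights_with_zp(b_weight, num_bits, group_size)
    q_w_awq = awq_pack(q_w, num_bits, size_k, size_n)
    marlin_q_w = ops.awq_marlin_repack(q_w_awq, size_k, size_n, num_bits)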
@@ -159,8 +159,14 @@ def dummy_model_gate_up() -> nn.Module:


 @pytest.fixture(scope="session")
-def sql_lora_files():
-    return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
+def sql_lora_huggingface_id():
+    # huggingface repo id is used to test lora runtime downloading.
+    return "yard1/llama-2-7b-sql-lora-test"
+
+
+@pytest.fixture(scope="session")
+def sql_lora_files(sql_lora_huggingface_id):
+    return snapshot_download(repo_id=sql_lora_huggingface_id)


 @pytest.fixture(scope="session")
@@ -29,7 +29,7 @@ def _create_lora_request(lora_id, long_context_infos):
     context_len = long_context_infos[lora_id]["context_length"]
     scaling_factor = context_len_to_scaling_factor[context_len]
     return LoRARequest(context_len, lora_id,
-                       long_context_infos[lora_id]["lora"],
+                       long_context_infos[lora_id]["lora"], None,
                        4096 * scaling_factor)

tests/lora/test_lora_huggingface.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+from typing import List
+
+import pytest
+
+from vllm.lora.models import LoRAModel
+from vllm.lora.utils import get_adapter_absolute_path
+from vllm.model_executor.models.llama import LlamaForCausalLM
+
+# Provide absolute path and huggingface lora ids
+lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]
+
+
+@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
+def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
+    lora_name = request.getfixturevalue(lora_fixture_name)
+    supported_lora_modules = LlamaForCausalLM.supported_lora_modules
+    packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
+    embedding_modules = LlamaForCausalLM.embedding_modules
+    embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
+    expected_lora_modules: List[str] = []
+    for module in supported_lora_modules:
+        if module in packed_modules_mapping:
+            expected_lora_modules.extend(packed_modules_mapping[module])
+        else:
+            expected_lora_modules.append(module)
+
+    lora_path = get_adapter_absolute_path(lora_name)
+
+    # lora loading should work for either absolute path and hugggingface id.
+    lora_model = LoRAModel.from_local_checkpoint(
+        lora_path,
+        expected_lora_modules,
+        lora_model_id=1,
+        device="cpu",
+        embedding_modules=embedding_modules,
+        embedding_padding_modules=embed_padding_modules)
+
+    # Assertions to ensure the model is loaded correctly
+    assert lora_model is not None, "LoRAModel is not loaded correctly"
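The new LoRA test above relies on get_adapter_absolute_path accepting either a local directory or a Hugging Face repo id; in the latter case the adapter appears to be downloaded and the local snapshot path returned, which LoRAModel.from_local_checkpoint then loads as usual. A minimal sketch under that assumption (repo id taken from the fixtures above; network access is needed for the download case):

    from vllm.lora.utils import get_adapter_absolute_path

    # A plain local path is returned unchanged; a repo id resolves to a
    # downloaded snapshot directory.
    lora_path = get_adapter_absolute_path("yard1/llama-2-7b-sql-lora-test")
    print(lora_path)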
Some files were not shown because too many files have changed in this diff.