Compare commits
80 Commits
| SHA1 |
|---|
| 1af090b57d |
| 3dad944485 |
| 105a40f53a |
| bbe9bd9684 |
| 4f65af0e25 |
| d79ced3292 |
| ab40644669 |
| 5d60def02c |
| ea8489fce2 |
| 1b20639a43 |
| b72af8f1ed |
| 9090bf02e7 |
| 7d648418b8 |
| 89be30fa7d |
| f8ecb84c02 |
| 5f036d2bcc |
| 380170038e |
| 220a47627b |
| beb89f68b4 |
| 390b495ff3 |
| 3a0e1fc070 |
| 6b7de1a030 |
| 5265631d15 |
| 2832e7b9f9 |
| 3a7dd7e367 |
| 223c19224b |
| f1f6cc10c7 |
| 3209b49033 |
| 1e4277d2d1 |
| 9b945daaf1 |
| 9c1352eb57 |
| 7a0b011dd5 |
| 63e835cbcc |
| 94b5edeb53 |
| ab7e6006d6 |
| 18bfcdd05c |
| 71d63ed72e |
| d75c40734a |
| 5b23c3f26f |
| 00efdc84ba |
| 91a61da9b1 |
| ef9b636e2d |
| 2709c0009a |
| dd7e8f5f64 |
| d2a68364c4 |
| 7e1081139d |
| 18473cf498 |
| 4df417d059 |
| 5d80a9178b |
| 8a25d3a71a |
| d10f8e1d43 |
| 14cc317ba4 |
| e1957c6ebd |
| 8cd5a992bf |
| 947f0b23cc |
| f780504d12 |
| bfc072addf |
| 2a18da257c |
| 6e01e8c1c8 |
| 9f659bf07f |
| 35c4bc20d9 |
| 218dc2ccda |
| 827cbcd37c |
| cb7a1c1cbf |
| 7878958c0d |
| ce036244c9 |
| 48cf1e413c |
| 97460585d9 |
| f745847ef7 |
| 6549aef245 |
| 50376faa7b |
| 4b61c6b669 |
| 79d64c4954 |
| 74cd5abdd1 |
| 28c3f12104 |
| c884819135 |
| 05921a9a7a |
| d0215a58e7 |
| 937e7b7d7c |
| aee8ef661a |
.buildkite/run-benchmarks.sh (new file, 63 lines)
@@ -0,0 +1,63 @@
# This script is run by buildkite to run the benchmarks and upload the results to buildkite

set -ex
set -o pipefail

# cd into parent directory of this file
cd "$(dirname "${BASH_SOURCE[0]}")/.."

(wget && curl) || (apt-get update && apt-get install -y wget curl)

# run benchmarks and upload the result to buildkite
python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt
bench_latency_exit_code=$?

python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt
bench_throughput_exit_code=$?

python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf &
server_pid=$!
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# wait for server to start, timeout after 600 seconds
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py \
    --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
    --model meta-llama/Llama-2-7b-chat-hf \
    --num-prompts 20 \
    --endpoint /v1/completions \
    --tokenizer meta-llama/Llama-2-7b-chat-hf 2>&1 | tee benchmark_serving.txt
bench_serving_exit_code=$?
kill $server_pid

# write the results into a markdown file
echo "### Latency Benchmarks" >> benchmark_results.md
sed -n '1p' benchmark_latency.txt >> benchmark_results.md # first line
echo "" >> benchmark_results.md
sed -n '$p' benchmark_latency.txt >> benchmark_results.md # last line

echo "### Throughput Benchmarks" >> benchmark_results.md
sed -n '1p' benchmark_throughput.txt >> benchmark_results.md # first line
echo "" >> benchmark_results.md
sed -n '$p' benchmark_throughput.txt >> benchmark_results.md # last line

echo "### Serving Benchmarks" >> benchmark_results.md
sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line
echo "" >> benchmark_results.md
tail -n 5 benchmark_serving.txt >> benchmark_results.md # last 5 lines

# upload the results to buildkite
/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md

# exit with the exit code of the benchmarks
if [ $bench_latency_exit_code -ne 0 ]; then
    exit $bench_latency_exit_code
fi

if [ $bench_throughput_exit_code -ne 0 ]; then
    exit $bench_throughput_exit_code
fi

if [ $bench_serving_exit_code -ne 0 ]; then
    exit $bench_serving_exit_code
fi
.buildkite/test-pipeline.yaml (new file, 51 lines)
@@ -0,0 +1,51 @@
# In this file, you can add more tests to run either by adding a new step or
# adding a new command to an existing step. See different options here for examples.
# This script will be feed into Jinja template in `test-template.j2` to generate
# the final pipeline yaml file.

steps:
- label: Regression Test
  command: pytest -v -s test_regression.py
  working_dir: "/vllm-workspace/tests" # optional

- label: AsyncEngine Test
  command: pytest -v -s async_engine

- label: Distributed Test
  command: pytest -v -s test_comm_ops.py
  working_dir: "/vllm-workspace/tests/distributed"
  num_gpus: 2 # only support 1 or 2 for now.

- label: Engine Test
  command: pytest -v -s engine

- label: Entrypoints Test
  command: pytest -v -s entrypoints

- label: Kernels Test
  command: pytest -v -s kernels
  soft_fail: true

- label: Models Test
  commands:
  - pytest -v -s models --forked
  soft_fail: true

- label: Prefix Caching Test
  commands:
  - pytest -v -s prefix_caching

- label: Samplers Test
  command: pytest -v -s samplers --forked

- label: Worker Test
  command: pytest -v -s worker

- label: LoRA Test
  command: pytest -v -s lora

- label: Benchmarks
  working_dir: "/vllm-workspace/.buildkite"
  commands:
  - pip install aiohttp
  - bash run-benchmarks.sh
.buildkite/test-template.j2 (new file, 54 lines)
@@ -0,0 +1,54 @@
{% set docker_image = "us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT" %}
{% set default_num_gpu = 1 %}
{% set default_working_dir = "/vllm-workspace/tests" %}

steps:
  - label: ":docker: build image"
    commands:
      - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
      - "docker push {{ docker_image }}"
    env:
      DOCKER_BUILDKIT: "1"
    retry:
      automatic:
        - exit_status: -1  # Agent was lost
          limit: 5
  - wait

  {% for step in steps %}
  - label: "{{ step.label }}"
    agents:
      queue: kubernetes
    soft_fail: {{ step.soft_fail or false }}
    retry:
      automatic:
        - exit_status: -1  # Agent was lost
          limit: 5
    plugins:
      - kubernetes:
          podSpec:
            volumes:
              - name: dshm
                emptyDir:
                  medium: Memory
            containers:
              - image: "{{ docker_image }}"
                command: ["bash"]
                args:
                  - "-c"
                  - "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'"
                resources:
                  requests:
                    nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
                  limits:
                    nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
                env:
                  - name: HF_TOKEN
                    valueFrom:
                      secretKeyRef:
                        name: hf-token-secret
                        key: token
                volumeMounts:
                  - mountPath: /dev/shm
                    name: dshm
  {% endfor %}
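As the comment at the top of .buildkite/test-pipeline.yaml notes, that steps file is fed through this Jinja template to produce the final Buildkite pipeline. For orientation only, a rough local sketch of that rendering step might look like the following; it is not part of this change and assumes the jinja2 and PyYAML packages are installed:

# Hypothetical local render of the pipeline; Buildkite performs the real one.
pip install jinja2 pyyaml
python3 - <<'EOF'
import jinja2
import yaml

# Load the step definitions and pass them to the template as `steps`.
steps = yaml.safe_load(open(".buildkite/test-pipeline.yaml"))["steps"]
template = jinja2.Template(open(".buildkite/test-template.j2").read())
print(template.render(steps=steps))
EOF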
.dockerignore (new file, 1 line)
@@ -0,0 +1 @@
vllm/*.so
.github/workflows/scripts/build.sh (vendored, 2 changed lines)
@@ -13,6 +13,8 @@ $python_executable -m pip install -r requirements.txt

 # Limit the number of parallel jobs to avoid OOM
 export MAX_JOBS=1
+# Make sure punica is built for the release (for LoRA)
+export VLLM_INSTALL_PUNICA_KERNELS=1

 # Build
 $python_executable setup.py bdist_wheel --dist-dir=dist
.github/workflows/yapf.yml (vendored, 2 changed lines)
@@ -28,4 +28,4 @@ jobs:
       pip install toml==0.10.2
     - name: Running yapf
       run: |
-        yapf --diff --recursive vllm tests
+        yapf --diff --recursive .
.gitignore (vendored, 3 changed lines)
@@ -181,3 +181,6 @@ _build/
 # hip files generated by PyTorch
 *.hip
 *_hip*
+
+# Benchmark dataset
+*.json
Dockerfile (39 changed lines)
@@ -1,7 +1,11 @@
+# The vLLM Dockerfile is used to construct vLLM image that can be directly used
+# to run the OpenAI compatible server.
+
+#################### BASE BUILD IMAGE ####################
 FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev

 RUN apt-get update -y \
-    && apt-get install -y python3-pip
+    && apt-get install -y python3-pip git

 WORKDIR /workspace

@@ -14,8 +18,10 @@ RUN --mount=type=cache,target=/root/.cache/pip \
 COPY requirements-dev.txt requirements-dev.txt
 RUN --mount=type=cache,target=/root/.cache/pip \
     pip install -r requirements-dev.txt
+#################### BASE BUILD IMAGE ####################

-# image to build pytorch extensions
+#################### EXTENSION BUILD IMAGE ####################
 FROM dev AS build

 # install build dependencies
@@ -30,6 +36,7 @@ COPY requirements.txt requirements.txt
 COPY pyproject.toml pyproject.toml
 COPY vllm/__init__.py vllm/__init__.py

+# cuda arch list used by torch
 ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
 ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
 # max jobs used by Ninja to build extensions
@@ -38,20 +45,30 @@ ENV MAX_JOBS=${max_jobs}
 # number of threads used by nvcc
 ARG nvcc_threads=8
 ENV NVCC_THREADS=$nvcc_threads
+# make sure punica kernels are built (for LoRA)
+ENV VLLM_INSTALL_PUNICA_KERNELS=1

 RUN python3 setup.py build_ext --inplace
+#################### EXTENSION Build IMAGE ####################

+
+#################### TEST IMAGE ####################
 # image to run unit testing suite
 FROM dev AS test

 # copy pytorch extensions separately to avoid having to rebuild
 # when python code changes
-COPY --from=build /workspace/vllm/*.so /workspace/vllm/
-COPY tests tests
-COPY vllm vllm
+WORKDIR /vllm-workspace
+# ADD is used to preserve directory structure
+ADD . /vllm-workspace/
+COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/
+# ignore build dependencies installation because we are using pre-complied extensions
+RUN rm pyproject.toml
+RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install . --verbose
+#################### TEST IMAGE ####################

-ENTRYPOINT ["python3", "-m", "pytest", "tests"]

+#################### RUNTIME BASE IMAGE ####################
 # use CUDA base as CUDA runtime dependencies are already installed via pip
 FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base

@@ -63,14 +80,10 @@ WORKDIR /workspace
 COPY requirements.txt requirements.txt
 RUN --mount=type=cache,target=/root/.cache/pip \
     pip install -r requirements.txt
+#################### RUNTIME BASE IMAGE ####################

-FROM vllm-base AS vllm
-COPY --from=build /workspace/vllm/*.so /workspace/vllm/
-COPY vllm vllm
-
-EXPOSE 8000
-ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"]

+#################### OPENAI API SERVER ####################
 # openai api server alternative
 FROM vllm-base AS vllm-openai
 # install additional dependencies for openai api server
@@ -81,4 +94,4 @@ COPY --from=build /workspace/vllm/*.so /workspace/vllm/
 COPY vllm vllm

 ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
+#################### OPENAI API SERVER ####################
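For context, the stages introduced above are selected with --target at build time; the test stage is the one the Buildkite template in this change builds and pushes. A minimal local sketch, with illustrative image tags rather than the CI registry path:

# Build the CI test image defined by the new TEST IMAGE stage.
DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --target test --tag vllm-test:dev --progress plain .

# Build the OpenAI-compatible server image from the same Dockerfile.
DOCKER_BUILDKIT=1 docker build --target vllm-openai --tag vllm-openai:dev .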
@@ -1,4 +1,24 @@
-FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1
+# default base image
+ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+
+FROM $BASE_IMAGE
+
+ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+
+RUN echo "Base image is $BASE_IMAGE"
+
+# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
+# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+
+# this does not always work for all rocm versions
+RUN LLVM_GFX_ARCH=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) && \
+    echo "LLVM_GFX_ARCH is $LLVM_GFX_ARCH"
+
+ARG FA_GFX_ARCHS="gfx90a;gfx942"
+RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
+
+ARG FA_BRANCH="3d2b6f5"
+RUN echo "FA_BRANCH is $FA_BRANCH"
+
 # Install some basic utilities
 RUN apt-get update && apt-get install python3 python3-pip -y
@@ -37,17 +57,23 @@ RUN mkdir libs \
     && cd libs \
     && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
     && cd flash-attention \
-    && git checkout 3d2b6f5 \
+    && git checkout ${FA_BRANCH} \
     && git submodule update --init \
-    && export GPU_ARCHS=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) \
+    && export GPU_ARCHS=${FA_GFX_ARCHS} \
-    && patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \
+    && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
+        patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
     && python3 setup.py install \
     && cd ..

 COPY ./ /app/vllm

 RUN python3 -m pip install --upgrade pip
-RUN pip install xformers==0.0.23 --no-deps
+RUN python3 -m pip install xformers==0.0.23 --no-deps
+
+# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
+# Manually removed it so that later steps of numpy upgrade can continue
+RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
+    rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi

 RUN cd /app \
     && cd vllm \
README.md (18 changed lines)
@@ -16,8 +16,18 @@ Easy, fast, and cheap LLM serving for everyone

 ---
+
+**The Second vLLM Bay Area Meetup (Jan 31st 5pm-7:30pm PT)**
+
+We are thrilled to announce our second vLLM Meetup!
+The vLLM team will share recent updates and roadmap.
+We will also have vLLM collaborators from IBM coming up to the stage to discuss their insights on LLM optimizations.
+Please register [here](https://lu.ma/ygxbpzhl) and join us!
+
+---
+
 *Latest News* 🔥
-- [2023/12] Added ROCm support to vLLM.
+- [2024/01] Added ROCm 6.0 support to vLLM.
+- [2023/12] Added ROCm 5.7 support to vLLM.
 - [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
 - [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
 - [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
@@ -36,7 +46,7 @@ vLLM is fast with:
 - Efficient management of attention key and value memory with **PagedAttention**
 - Continuous batching of incoming requests
 - Fast model execution with CUDA/HIP graph
-- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629)
+- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache
 - Optimized CUDA kernels

 vLLM is flexible and easy to use with:
@@ -47,6 +57,8 @@ vLLM is flexible and easy to use with:
 - Streaming outputs
 - OpenAI-compatible API server
 - Support NVIDIA GPUs and AMD GPUs
+- (Experimental) Prefix caching support
+- (Experimental) Multi-lora support

 vLLM seamlessly supports many Hugging Face models, including the following architectures:

@@ -68,6 +80,8 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
 - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
 - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
 - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
+- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
+- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
 - Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)

 Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
@@ -24,6 +24,7 @@ def main(args: argparse.Namespace):
         trust_remote_code=args.trust_remote_code,
         dtype=args.dtype,
         enforce_eager=args.enforce_eager,
+        kv_cache_dtype=args.kv_cache_dtype,
     )

     sampling_params = SamplingParams(
@@ -65,7 +66,9 @@ def main(args: argparse.Namespace):
     if args.profile:
         profile_dir = args.profile_result_dir
         if not profile_dir:
-            profile_dir = Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
+            profile_dir = Path(
+                "."
+            ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
         print(f"Profiling (results will be saved to '{profile_dir}')...")
         run_to_completion(profile_dir=args.profile_result_dir)
         return
@@ -115,6 +118,13 @@ if __name__ == '__main__':
     parser.add_argument('--enforce-eager',
                         action='store_true',
                         help='enforce eager mode and disable CUDA graph')
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=['auto', 'fp8_e5m2'],
+        default='auto',
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type.')
     parser.add_argument(
         '--profile',
         action='store_true',
@@ -123,9 +133,7 @@ if __name__ == '__main__':
         '--profile-result-dir',
         type=str,
         default=None,
-        help=(
-            'path to save the pytorch profiler output. Can be visualized '
-            'with ui.perfetto.dev or Tensorboard.'
-        ))
+        help=('path to save the pytorch profiler output. Can be visualized '
+              'with ui.perfetto.dev or Tensorboard.'))
     args = parser.parse_args()
     main(args)
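The hunks above add a --kv-cache-dtype option, which defaults to the model data type and can opt into an FP8 E5M2 KV cache; the throughput benchmark further down gains the same flag. Assuming these hunks belong to benchmarks/benchmark_latency.py, the script that run-benchmarks.sh invokes, a minimal sketch of exercising the new flag:

# Default: the KV cache keeps the model dtype.
python3 benchmarks/benchmark_latency.py --kv-cache-dtype auto

# Store the KV cache in FP8 E5M2 instead (assumes a build with the FP8 kernels enabled).
python3 benchmarks/benchmark_latency.py --kv-cache-dtype fp8_e5m2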
@@ -24,6 +24,7 @@ from typing import AsyncGenerator, List, Tuple

 import aiohttp
 import numpy as np
+from tqdm.asyncio import tqdm
 from transformers import PreTrainedTokenizerBase
 from vllm.transformers_utils.tokenizer import get_tokenizer

@@ -40,15 +41,10 @@ def sample_requests(
     with open(dataset_path) as f:
         dataset = json.load(f)
     # Filter out the conversations with less than 2 turns.
-    dataset = [
-        data for data in dataset
-        if len(data["conversations"]) >= 2
-    ]
+    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
     # Only keep the first two turns of each conversation.
-    dataset = [
-        (data["conversations"][0]["value"], data["conversations"][1]["value"])
-        for data in dataset
-    ]
+    dataset = [(data["conversations"][0]["value"],
+                data["conversations"][1]["value"]) for data in dataset]

     # Tokenize the prompts and completions.
     prompts = [prompt for prompt, _ in dataset]
@@ -96,15 +92,9 @@ async def get_request(
         await asyncio.sleep(interval)


-async def send_request(
-    backend: str,
-    api_url: str,
-    prompt: str,
-    prompt_len: int,
-    output_len: int,
-    best_of: int,
-    use_beam_search: bool,
-) -> None:
+async def send_request(backend: str, model: str, api_url: str, prompt: str,
+                       prompt_len: int, output_len: int, best_of: int,
+                       use_beam_search: bool, pbar: tqdm) -> None:
     request_start_time = time.perf_counter()

     headers = {"User-Agent": "Benchmark Client"}
@@ -120,6 +110,8 @@ async def send_request(
             "ignore_eos": True,
             "stream": False,
         }
+        if model is not None:
+            pload["model"] = model
     elif backend == "tgi":
         assert not use_beam_search
         params = {
@@ -137,7 +129,8 @@ async def send_request(
     timeout = aiohttp.ClientTimeout(total=3 * 3600)
     async with aiohttp.ClientSession(timeout=timeout) as session:
         while True:
-            async with session.post(api_url, headers=headers, json=pload) as response:
+            async with session.post(api_url, headers=headers,
+                                    json=pload) as response:
                 chunks = []
                 async for chunk, _ in response.content.iter_chunks():
                     chunks.append(chunk)
@@ -151,10 +144,12 @@ async def send_request(
     request_end_time = time.perf_counter()
     request_latency = request_end_time - request_start_time
     REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
+    pbar.update(1)


 async def benchmark(
     backend: str,
+    model: str,
     api_url: str,
     input_requests: List[Tuple[str, int, int]],
     best_of: int,
@@ -162,13 +157,15 @@ async def benchmark(
     request_rate: float,
 ) -> None:
     tasks: List[asyncio.Task] = []
+    pbar = tqdm(total=len(input_requests))
     async for request in get_request(input_requests, request_rate):
         prompt, prompt_len, output_len = request
-        task = asyncio.create_task(send_request(backend, api_url, prompt,
-                                                prompt_len, output_len,
-                                                best_of, use_beam_search))
+        task = asyncio.create_task(
+            send_request(backend, model, api_url, prompt, prompt_len,
+                         output_len, best_of, use_beam_search, pbar))
         tasks.append(task)
     await asyncio.gather(*tasks)
+    pbar.close()


 def main(args: argparse.Namespace):
@@ -176,13 +173,15 @@ def main(args: argparse.Namespace):
     random.seed(args.seed)
     np.random.seed(args.seed)

-    api_url = f"http://{args.host}:{args.port}/generate"
-    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
+    api_url = f"{args.protocol}://{args.host}:{args.port}{args.endpoint}"
+    tokenizer = get_tokenizer(args.tokenizer,
+                              trust_remote_code=args.trust_remote_code)
     input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

     benchmark_start_time = time.perf_counter()
-    asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
-                          args.use_beam_search, args.request_rate))
+    asyncio.run(
+        benchmark(args.backend, args.model, api_url, input_requests,
+                  args.best_of, args.use_beam_search, args.request_rate))
     benchmark_end_time = time.perf_counter()
     benchmark_time = benchmark_end_time - benchmark_start_time
     print(f"Total time: {benchmark_time:.2f} s")
@@ -196,10 +195,8 @@ def main(args: argparse.Namespace):
         for prompt_len, output_len, latency in REQUEST_LATENCY
     ])
     print(f"Average latency per token: {avg_per_token_latency:.2f} s")
-    avg_per_output_token_latency = np.mean([
-        latency / output_len
-        for _, output_len, latency in REQUEST_LATENCY
-    ])
+    avg_per_output_token_latency = np.mean(
+        [latency / output_len for _, output_len, latency in REQUEST_LATENCY])
     print("Average latency per output token: "
           f"{avg_per_output_token_latency:.2f} s")
@@ -207,27 +204,46 @@ def main(args: argparse.Namespace):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Benchmark the online serving throughput.")
-    parser.add_argument("--backend", type=str, default="vllm",
+    parser.add_argument("--backend",
+                        type=str,
+                        default="vllm",
                         choices=["vllm", "tgi"])
+    parser.add_argument("--protocol",
+                        type=str,
+                        default="http",
+                        choices=["http", "https"])
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8000)
-    parser.add_argument("--dataset", type=str, required=True,
+    parser.add_argument("--endpoint", type=str, default="/generate")
+    parser.add_argument("--model", type=str, default=None)
+    parser.add_argument("--dataset",
+                        type=str,
+                        required=True,
                         help="Path to the dataset.")
-    parser.add_argument("--tokenizer", type=str, required=True,
+    parser.add_argument("--tokenizer",
+                        type=str,
+                        required=True,
                         help="Name or path of the tokenizer.")
-    parser.add_argument("--best-of", type=int, default=1,
+    parser.add_argument("--best-of",
+                        type=int,
+                        default=1,
                         help="Generates `best_of` sequences per prompt and "
                         "returns the best one.")
     parser.add_argument("--use-beam-search", action="store_true")
-    parser.add_argument("--num-prompts", type=int, default=1000,
+    parser.add_argument("--num-prompts",
+                        type=int,
+                        default=1000,
                         help="Number of prompts to process.")
-    parser.add_argument("--request-rate", type=float, default=float("inf"),
+    parser.add_argument("--request-rate",
+                        type=float,
+                        default=float("inf"),
                         help="Number of requests per second. If this is inf, "
                         "then all the requests are sent at time 0. "
                         "Otherwise, we use Poisson process to synthesize "
                         "the request arrival times.")
     parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument('--trust-remote-code', action='store_true',
+    parser.add_argument('--trust-remote-code',
+                        action='store_true',
                         help='trust remote code from huggingface')
     args = parser.parse_args()
     main(args)
@@ -71,6 +71,7 @@ def run_vllm(
     dtype: str,
     max_model_len: Optional[int],
     enforce_eager: bool,
+    kv_cache_dtype: str,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
@@ -83,6 +84,7 @@ def run_vllm(
         dtype=dtype,
         max_model_len=max_model_len,
         enforce_eager=enforce_eager,
+        kv_cache_dtype=kv_cache_dtype,
     )

     # Add the requests to the engine.
@@ -206,7 +208,8 @@ def main(args: argparse.Namespace):
                                 args.quantization, args.tensor_parallel_size,
                                 args.seed, args.n, args.use_beam_search,
                                 args.trust_remote_code, args.dtype,
-                                args.max_model_len, args.enforce_eager)
+                                args.max_model_len, args.enforce_eager,
+                                args.kv_cache_dtype)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -284,6 +287,13 @@ if __name__ == "__main__":
     parser.add_argument("--enforce-eager",
                         action="store_true",
                         help="enforce eager execution")
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=["auto", "fp8_e5m2"],
+        default="auto",
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type.')
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
@@ -1,9 +1,11 @@
+from typing import Optional
 import argparse
 import random
 import time

 import torch

+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
 from vllm._C import ops

 NUM_BLOCKS = 1024
@@ -23,6 +25,7 @@ def main(
     dtype: torch.dtype,
     seed: int,
     do_profile: bool,
+    kv_cache_dtype: Optional[str] = None,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
@@ -59,15 +62,10 @@ def main(
     block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

     # Create the KV cache.
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
-    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
-    key_cache.uniform_(-scale, scale)
-    value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
-    value_cache = torch.empty(size=value_cache_shape,
-                              dtype=dtype,
-                              device="cuda")
-    value_cache.uniform_(-scale, scale)
+    key_caches, value_caches = create_kv_caches_with_random(
+        NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
+        dtype)
+    key_cache, value_cache = key_caches[0], value_caches[0]

     # Prepare for the paged attention kernel.
     output = torch.empty_like(query)
@@ -106,6 +104,7 @@ def main(
             block_size,
             max_context_len,
             alibi_slopes,
+            kv_cache_dtype,
         )
     elif version == "v2":
         ops.paged_attention_v2(
@@ -123,6 +122,7 @@ def main(
             block_size,
             max_context_len,
             alibi_slopes,
+            kv_cache_dtype,
         )
     else:
         raise ValueError(f"Invalid version: {version}")
@@ -168,16 +168,18 @@ if __name__ == '__main__':
                         default="half")
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--profile", action="store_true")
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=["auto", "fp8_e5m2"],
+        default="auto",
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type.')
     args = parser.parse_args()
     print(args)

     if args.num_query_heads % args.num_kv_heads != 0:
         raise ValueError("num_query_heads must be divisible by num_kv_heads")
-    dtype_to_torch_dtype = {
-        "half": torch.half,
-        "bfloat16": torch.bfloat16,
-        "float": torch.float,
-    }
     main(
         version=args.version,
         num_seqs=args.batch_size,
@@ -187,7 +189,8 @@ if __name__ == '__main__':
         head_size=args.head_size,
         block_size=args.block_size,
         use_alibi=args.use_alibi,
-        dtype=dtype_to_torch_dtype[args.dtype],
+        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
         seed=args.seed,
         do_profile=args.profile,
+        kv_cache_dtype=args.kv_cache_dtype,
     )
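This micro-benchmark now builds its KV cache through create_kv_caches_with_random and accepts the same --kv-cache-dtype choices as the other benchmark scripts. Assuming these hunks are the paged-attention kernel benchmark under benchmarks/kernels/, a hedged sketch of a run; the script path and the --version value are taken from the surrounding code, not asserted by this diff:

# Benchmark the v2 paged attention kernel against an FP8 E5M2 KV cache.
python3 benchmarks/kernels/benchmark_paged_attention.py --version v2 --kv-cache-dtype fp8_e5m2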
@@ -4,3 +4,4 @@
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
 #include "dtype_bfloat16.cuh"
+#include "dtype_fp8_e5m2.cuh"
@@ -25,6 +25,7 @@

 #include "attention_dtypes.h"
 #include "attention_utils.cuh"
+#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"

 #include <algorithm>

@@ -79,17 +80,19 @@ inline __device__ float block_sum(float* red_smem, float sum) {
 // Grid: (num_heads, num_seqs, max_num_partitions).
 template<
   typename scalar_t,
+  typename cache_t,
   int HEAD_SIZE,
   int BLOCK_SIZE,
   int NUM_THREADS,
+  bool IS_FP8_E5M2_KV_CACHE,
   int PARTITION_SIZE = 0> // Zero means no partitioning.
 __device__ void paged_attention_kernel(
   float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
   float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
   scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
+  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
   const int num_kv_heads,                 // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
@@ -145,6 +148,9 @@ __device__ void paged_attention_kernel(
   constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
   using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
   using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+#ifdef ENABLE_FP8_E5M2
+  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
+#endif

   constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
   constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
@@ -176,7 +182,7 @@ __device__ void paged_attention_kernel(

   // x == THREAD_GROUP_SIZE * VEC_SIZE
   // Each thread group fetches x elements from the key at a time.
-  constexpr int x = 16 / sizeof(scalar_t);
+  constexpr int x = 16 / sizeof(cache_t);
   float qk_max = -FLT_MAX;

   // Iterate over the key blocks.
@@ -202,13 +208,23 @@ __device__ void paged_attention_kernel(

 #pragma unroll
       for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
-        const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+        const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                                         + kv_head_idx * kv_head_stride
                                         + physical_block_offset * x;
         const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
         const int offset1 = (vec_idx * VEC_SIZE) / x;
         const int offset2 = (vec_idx * VEC_SIZE) % x;
-        k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+        if constexpr (IS_FP8_E5M2_KV_CACHE) {
+#ifdef ENABLE_FP8_E5M2
+          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          // Vector conversion from Quant_vec to K_vec.
+          k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
+#else
+          assert(false);
+#endif
+        } else {
+          k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+        }
       }

       // Compute dot product.
@@ -282,6 +298,9 @@ __device__ void paged_attention_kernel(
   constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
   using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
   using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+#ifdef ENABLE_FP8_E5M2
+  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
+#endif
   using Float_L_vec = typename FloatVec<L_vec>::Type;

   constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
@@ -307,14 +326,25 @@ __device__ void paged_attention_kernel(
       L_vec logits_vec;
       from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));

-      const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+      const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
                                       + kv_head_idx * kv_head_stride;
 #pragma unroll
       for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
         const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
         if (row_idx < HEAD_SIZE) {
           const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
-          V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
+          V_vec v_vec;
+          if constexpr (IS_FP8_E5M2_KV_CACHE) {
+#ifdef ENABLE_FP8_E5M2
+            V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+            // Vector conversion from V_quant_vec to V_vec.
+            v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
+#else
+            assert(false);
+#endif
+          } else {
+            v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
+          }
           if (block_idx == num_context_blocks - 1) {
             // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
             // we should explicitly zero out the values since they may contain NaNs.
@@ -395,14 +425,16 @@ __device__ void paged_attention_kernel(
 // Grid: (num_heads, num_seqs, 1).
 template<
   typename scalar_t,
+  typename cache_t,
   int HEAD_SIZE,
   int BLOCK_SIZE,
-  int NUM_THREADS>
+  int NUM_THREADS,
+  bool IS_FP8_E5M2_KV_CACHE>
 __global__ void paged_attention_v1_kernel(
   scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
+  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
   const int num_kv_heads,                 // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
@@ -412,7 +444,7 @@ __global__ void paged_attention_v1_kernel(
   const int q_stride,
   const int kv_block_stride,
   const int kv_head_stride) {
-  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
     /* exp_sums */ nullptr, /* max_logits */ nullptr,
     out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
     max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
@@ -421,17 +453,19 @@ __global__ void paged_attention_v1_kernel(
 // Grid: (num_heads, num_seqs, max_num_partitions).
 template<
   typename scalar_t,
+  typename cache_t,
   int HEAD_SIZE,
   int BLOCK_SIZE,
   int NUM_THREADS,
+  bool IS_FP8_E5M2_KV_CACHE,
   int PARTITION_SIZE>
 __global__ void paged_attention_v2_kernel(
   float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
   float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
   scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
+  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
   const int num_kv_heads,                 // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
@@ -441,7 +475,7 @@ __global__ void paged_attention_v2_kernel(
   const int q_stride,
   const int kv_block_stride,
   const int kv_head_stride) {
-  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
     exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
     block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
     q_stride, kv_block_stride, kv_head_stride);
@@ -550,10 +584,10 @@ __global__ void paged_attention_v2_reduce_kernel(

 #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
   VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
-    ((void*)vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>), \
-    shared_mem_size); \
-  vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
-  <<<grid, block, shared_mem_size, stream>>>( \
+    ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
+      IS_FP8_E5M2_KV_CACHE>), shared_mem_size); \
+  vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
+    IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
     out_ptr, \
     query_ptr, \
     key_cache_ptr, \
@@ -571,7 +605,9 @@ __global__ void paged_attention_v2_reduce_kernel(
 // TODO(woosuk): Tune NUM_THREADS.
 template<
   typename T,
+  typename CACHE_T,
   int BLOCK_SIZE,
+  bool IS_FP8_E5M2_KV_CACHE,
   int NUM_THREADS = 128>
 void paged_attention_v1_launcher(
   torch::Tensor& out,
@@ -602,8 +638,8 @@ void paged_attention_v1_launcher(

   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
-  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
-  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
+  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();

@@ -647,35 +683,35 @@ void paged_attention_v1_launcher(
   }
 }

-#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \
-  paged_attention_v1_launcher<T, BLOCK_SIZE>( \
+#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
+  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
     out, \
     query, \
     key_cache, \
     value_cache, \
     num_kv_heads, \
     scale, \
     block_tables, \
     context_lens, \
     max_context_len, \
     alibi_slopes);

 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
   switch (block_size) { \
     case 8: \
-      CALL_V1_LAUNCHER(T, 8); \
+      CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
       break; \
     case 16: \
-      CALL_V1_LAUNCHER(T, 16); \
+      CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
      break; \
     case 32: \
-      CALL_V1_LAUNCHER(T, 32); \
+      CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
       break; \
     default: \
       TORCH_CHECK(false, "Unsupported block size: ", block_size); \
       break; \
   }

 void paged_attention_v1(
@@ -689,20 +725,36 @@ void paged_attention_v1(
|
|||||||
torch::Tensor& context_lens, // [num_seqs]
|
torch::Tensor& context_lens, // [num_seqs]
|
||||||
int block_size,
|
int block_size,
|
||||||
int max_context_len,
|
int max_context_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
if (query.dtype() == at::ScalarType::Float) {
|
const std::string& kv_cache_dtype) {
|
||||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float);
|
if (kv_cache_dtype == "auto") {
|
||||||
} else if (query.dtype() == at::ScalarType::Half) {
|
if (query.dtype() == at::ScalarType::Float) {
|
||||||
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t);
|
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
|
||||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
} else if (query.dtype() == at::ScalarType::Half) {
|
||||||
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
|
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
|
||||||
|
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||||
|
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||||
|
}
|
||||||
|
} else if (kv_cache_dtype == "fp8_e5m2") {
|
||||||
|
if (query.dtype() == at::ScalarType::Float) {
|
||||||
|
CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
||||||
|
} else if (query.dtype() == at::ScalarType::Half) {
|
||||||
|
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
|
||||||
|
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||||
|
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
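As an editorial aside (not part of the diff): the new dispatch layers expand mechanically. Taking the half-query / fp8-cache arm with block_size 16 as an example, CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true) selects CALL_V1_LAUNCHER(uint16_t, uint8_t, 16, true), which instantiates

  paged_attention_v1_launcher<uint16_t, uint8_t, 16, true>(
      out, query, key_cache, value_cache, num_kv_heads, scale,
      block_tables, context_lens, max_context_len, alibi_slopes);

so a half-precision query attends over a uint8_t (fp8 e5m2) KV cache.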
 #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
-  vllm::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE> \
+  vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
+  IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE> \
   <<<grid, block, shared_mem_size, stream>>>( \
     exp_sums_ptr, \
     max_logits_ptr, \
@@ -730,7 +782,9 @@ void paged_attention_v1(

 template<
   typename T,
+  typename CACHE_T,
   int BLOCK_SIZE,
+  bool IS_FP8_E5M2_KV_CACHE,
   int NUM_THREADS = 128,
   int PARTITION_SIZE = 512>
 void paged_attention_v2_launcher(
@@ -768,8 +822,8 @@ void paged_attention_v2_launcher(
   float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
   T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
-  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
-  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
+  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();

@@ -816,38 +870,38 @@ void paged_attention_v2_launcher(
   }
 }

-#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \
-  paged_attention_v2_launcher<T, BLOCK_SIZE>( \
+#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
+  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
     out, \
     exp_sums, \
     max_logits, \
     tmp_out, \
     query, \
     key_cache, \
     value_cache, \
     num_kv_heads, \
     scale, \
     block_tables, \
     context_lens, \
     max_context_len, \
     alibi_slopes);

 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
   switch (block_size) { \
     case 8: \
-      CALL_V2_LAUNCHER(T, 8); \
+      CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
       break; \
     case 16: \
-      CALL_V2_LAUNCHER(T, 16); \
+      CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
       break; \
     case 32: \
-      CALL_V2_LAUNCHER(T, 32); \
+      CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
       break; \
     default: \
       TORCH_CHECK(false, "Unsupported block size: ", block_size); \
       break; \
   }

 void paged_attention_v2(
@@ -864,15 +918,30 @@ void paged_attention_v2(
   torch::Tensor& context_lens,  // [num_seqs]
   int block_size,
   int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes) {
-  if (query.dtype() == at::ScalarType::Float) {
-    CALL_V2_LAUNCHER_BLOCK_SIZE(float);
-  } else if (query.dtype() == at::ScalarType::Half) {
-    CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t);
-  } else if (query.dtype() == at::ScalarType::BFloat16) {
-    CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const std::string& kv_cache_dtype) {
+  if (kv_cache_dtype == "auto") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+  } else if (kv_cache_dtype == "fp8_e5m2") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
   } else {
-    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
   }
 }
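The hunks above only thread the new cache_t / IS_FP8_E5M2_KV_CACHE template parameters through; the body of paged_attention_kernel, where the cache is actually dequantized, lies outside this compare. A minimal editorial sketch of how such a load plausibly branches on the flag; the vec_conversion helper name follows the quant_utils.cuh naming used elsewhere in this change, while Vec, Quant_vec, and the call site are assumptions:

// Editorial sketch, not part of the diff.
template<typename Vec, typename Quant_vec, typename cache_t,
         bool IS_FP8_E5M2_KV_CACHE>
__device__ inline Vec load_cache_vec(const cache_t* vec_ptr) {
  if constexpr (IS_FP8_E5M2_KV_CACHE) {
#ifdef ENABLE_FP8_E5M2
    // cache_t is uint8_t here: read the packed fp8 bytes, then widen them to
    // the query/compute vector type before the dot product.
    Quant_vec quant = *reinterpret_cast<const Quant_vec*>(vec_ptr);
    return fp8_e5m2_unscaled::vec_conversion<Vec, Quant_vec>(quant);
#else
    assert(false);
    return Vec{};
#endif
  } else {
    // cache_t matches scalar_t: plain vectorized load, as before this change.
    return *reinterpret_cast<const Vec*>(vec_ptr);
  }
}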
35  csrc/attention/dtype_fp8_e5m2.cuh  Normal file
@@ -0,0 +1,35 @@
#pragma once

#include "attention_generic.cuh"

#include <stdint.h>
#ifdef ENABLE_FP8_E5M2
#include <cuda_fp8.h>
#endif

namespace vllm {
#ifdef ENABLE_FP8_E5M2
// fp8 vector types for quantization of kv cache

template<>
struct Vec<uint8_t, 1> {
  using Type = uint8_t;
};

template<>
struct Vec<uint8_t, 2> {
  using Type = uint16_t;
};

template<>
struct Vec<uint8_t, 4> {
  using Type = uint32_t;
};

template<>
struct Vec<uint8_t, 8> {
  using Type = uint2;
};
#endif // ENABLE_FP8_E5M2

} // namespace vllm
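These Vec specializations only pick the packed storage types; the byte-level e5m2 conversion itself lives in csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh, which this compare does not show. A minimal sketch of the scalar half <-> fp8 case built directly on the cuda_fp8.h intrinsics (CUDA 11.8+); the helper names are illustrative, not the real ones:

// Editorial sketch mirroring the ENABLE_FP8_E5M2-gated path.
#include <cuda_fp8.h>
#include <cuda_fp16.h>
#include <stdint.h>

namespace fp8_e5m2_unscaled_sketch {

// half (as raw uint16_t bits) -> one fp8 e5m2 byte, saturating to finite
__device__ inline uint8_t half_to_fp8(uint16_t bits) {
  __half_raw raw;
  raw.x = bits;
  return static_cast<uint8_t>(
      __nv_cvt_halfraw_to_fp8(raw, __NV_SATFINITE, __NV_E5M2));
}

// one fp8 e5m2 byte -> half (as raw uint16_t bits)
__device__ inline uint16_t fp8_to_half(uint8_t byte) {
  __half_raw raw = __nv_cvt_fp8_to_halfraw(
      static_cast<__nv_fp8_storage_t>(byte), __NV_E5M2);
  return raw.x;
}

} // namespace fp8_e5m2_unscaled_sketch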
@@ -20,7 +20,8 @@ void reshape_and_cache(
   torch::Tensor& value,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
-  torch::Tensor& slot_mapping);
+  torch::Tensor& slot_mapping,
+  const std::string& kv_cache_dtype);

 void gather_cached_kv(
   torch::Tensor& key,
@@ -28,3 +29,8 @@ void gather_cached_kv(
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
   torch::Tensor& slot_mapping);
+
+// Just for unittest
+void convert_fp8_e5m2(
+  torch::Tensor& src_cache,
+  torch::Tensor& dst_cache);

@@ -4,6 +4,7 @@

 #include "cuda_compat.h"
 #include "dispatch_utils.h"
+#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"

 #include <algorithm>
 #include <cassert>
@@ -34,7 +35,7 @@ void swap_blocks(
   char *dst_ptr = static_cast<char*>(dst.data_ptr());

   const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
-  const at::cuda::OptionalCUDAGuard device_guard(src_device);
+  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   // NOTE(woosuk): This can be slow if the number of blocks is large.
   for (const auto& pair : block_mapping) {
@@ -131,7 +132,7 @@ void copy_blocks(
   dim3 block(std::min(1024, numel_per_block));
   const at::cuda::OptionalCUDAGuard device_guard(cache_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
+  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
       vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
         key_cache_ptrs_tensor.data_ptr<int64_t>(),
@@ -143,12 +144,12 @@ void copy_blocks(

 namespace vllm {

-template<typename scalar_t>
+template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
 __global__ void reshape_and_cache_kernel(
   const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
   const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
-  scalar_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
+  cache_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
+  cache_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size, block_size]
   const int64_t* __restrict__ slot_mapping,  // [num_tokens]
   const int key_stride,
   const int value_stride,
@@ -185,19 +186,45 @@ __global__ void reshape_and_cache_kernel(
       + head_idx * head_size * block_size
       + head_offset * block_size
       + block_offset;
-    key_cache[tgt_key_idx] = key[src_key_idx];
-    value_cache[tgt_value_idx] = value[src_value_idx];
+    scalar_t tgt_key = key[src_key_idx];
+    scalar_t tgt_value = value[src_value_idx];
+    if constexpr (is_fp8_e5m2_kv_cache) {
+#ifdef ENABLE_FP8_E5M2
+      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
+      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
+#else
+      assert(false);
+#endif
+    } else {
+      key_cache[tgt_key_idx] = tgt_key;
+      value_cache[tgt_value_idx] = tgt_value;
+    }
   }
 }

 } // namespace vllm

+#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
+  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
+    reinterpret_cast<KV_T*>(key.data_ptr()), \
+    reinterpret_cast<KV_T*>(value.data_ptr()), \
+    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
+    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
+    slot_mapping.data_ptr<int64_t>(), \
+    key_stride, \
+    value_stride, \
+    num_heads, \
+    head_size, \
+    block_size, \
+    x);
+
 void reshape_and_cache(
   torch::Tensor& key,           // [num_tokens, num_heads, head_size]
   torch::Tensor& value,         // [num_tokens, num_heads, head_size]
   torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& slot_mapping)  // [num_tokens]
+  torch::Tensor& slot_mapping,  // [num_tokens]
+  const std::string& kv_cache_dtype)
 {
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
@@ -212,23 +239,25 @@ void reshape_and_cache(
   dim3 block(std::min(num_heads * head_size, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-    key.scalar_type(),
-    "reshape_and_cache_kernel",
-    [&] {
-      vllm::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        key.data_ptr<scalar_t>(),
-        value.data_ptr<scalar_t>(),
-        key_cache.data_ptr<scalar_t>(),
-        value_cache.data_ptr<scalar_t>(),
-        slot_mapping.data_ptr<int64_t>(),
-        key_stride,
-        value_stride,
-        num_heads,
-        head_size,
-        block_size,
-        x);
-    });
+  if (kv_cache_dtype == "auto") {
+    if (key.dtype() == at::ScalarType::Float) {
+      CALL_RESHAPE_AND_CACHE(float, float, false);
+    } else if (key.dtype() == at::ScalarType::Half) {
+      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
+    } else if (key.dtype() == at::ScalarType::BFloat16) {
+      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
+    }
+  } else if (kv_cache_dtype == "fp8_e5m2") {
+    if (key.dtype() == at::ScalarType::Float) {
+      CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
+    } else if (key.dtype() == at::ScalarType::Half) {
+      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
+    } else if (key.dtype() == at::ScalarType::BFloat16) {
+      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
+    }
+  } else {
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+  }
 }

 namespace vllm {
@@ -373,7 +402,7 @@ void gather_cached_kv(
   dim3 block(std::min(num_heads * head_size, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
+  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
     key.scalar_type(),
     "gather_cached_kv_kernel_optimized",
     [&] {
@@ -391,3 +420,55 @@ void gather_cached_kv(
         x);
     });
 }
+
+namespace vllm {
+
+template<typename Tout, typename Tin>
+__global__ void convert_fp8_e5m2_kernel(
+  const Tin* __restrict__ src_cache,
+  Tout* __restrict__ dst_cache,
+  const int64_t block_stride) {
+  const int64_t block_idx = blockIdx.x;
+  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
+    int64_t idx = block_idx * block_stride + i;
+#ifdef ENABLE_FP8_E5M2
+    dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
+#else
+    assert(false);
+#endif
+  }
+}
+
+} // namespace vllm
+
+#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \
+  vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
+    reinterpret_cast<Tin*>(src_cache.data_ptr()), \
+    reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
+    block_stride);
+
+void convert_fp8_e5m2(
+  torch::Tensor& src_cache,
+  torch::Tensor& dst_cache)
+{
+  int64_t num_blocks = src_cache.size(0);
+  int64_t block_stride = src_cache.stride(0);
+
+  dim3 grid(num_blocks);
+  dim3 block(std::min(block_stride, int64_t(512)));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  if (src_cache.dtype() == at::ScalarType::Float) {
+    CALL_CONVERT_FP8_E5M2(uint8_t, float);
+  } else if (src_cache.dtype() == at::ScalarType::Half) {
+    CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
+  } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
+    CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
+  } else if (dst_cache.dtype() == at::ScalarType::Float) {
+    CALL_CONVERT_FP8_E5M2(float, uint8_t);
+  } else if (dst_cache.dtype() == at::ScalarType::Half) {
+    CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
+  } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
+    CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
+  }
+}

@@ -5,3 +5,6 @@
 int get_device_attribute(
   int attribute,
   int device_id);
+
+int get_max_shared_memory_per_block_device_attribute(
+  int device_id);

@@ -1,5 +1,6 @@
 #ifdef USE_ROCM
   #include <hip/hip_runtime.h>
+  #include <hip/hip_runtime_api.h>
 #endif
 int get_device_attribute(
   int attribute,
@@ -15,3 +16,20 @@ int get_device_attribute(
   cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
   return value;
 }
+
+
+int get_max_shared_memory_per_block_device_attribute(
+  int device_id)
+{
+  int attribute;
+  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
+  // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
+
+#ifdef USE_ROCM
+  attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
+#else
+  attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
+#endif
+
+  return get_device_attribute(attribute, device_id);
+}
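A hedged sketch of how the "just for unittest" convert_fp8_e5m2 helper declared and defined above might be exercised end to end; the shapes, tolerances, and surrounding harness are assumptions, not part of this change, and the declaration is assumed visible to the caller:

#include <torch/torch.h>

void convert_fp8_e5m2_roundtrip_sketch() {
  // half-precision source cache block; the shape is arbitrary for illustration
  auto src = torch::rand({16, 8, 16, 16},
                         torch::dtype(torch::kHalf).device(torch::kCUDA));
  auto quantized = torch::empty_like(src, torch::dtype(torch::kUInt8));
  auto restored = torch::empty_like(src);
  convert_fp8_e5m2(src, quantized);       // half -> fp8 bytes (dispatch on src dtype)
  convert_fp8_e5m2(quantized, restored);  // fp8 bytes -> half (dispatch on dst dtype)
  // e5m2 keeps only 2 mantissa bits, so only coarse agreement is expected
  TORCH_CHECK(torch::allclose(src.to(torch::kFloat), restored.to(torch::kFloat),
                              /*rtol=*/0.25, /*atol=*/0.25));
}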
148  csrc/custom_all_reduce.cu  Normal file
@@ -0,0 +1,148 @@
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

#include "custom_all_reduce.cuh"

// fake pointer type
using fptr_t = uint64_t;
static_assert(sizeof(void *) == sizeof(fptr_t));

fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
                      const std::vector<std::string> &handles,
                      const std::vector<int64_t> &offsets, int rank,
                      bool full_nvlink) {
  int world_size = offsets.size();
  if (world_size > 8)
    throw std::invalid_argument("world size > 8 is not supported");
  if (world_size % 2 != 0)
    throw std::invalid_argument("Odd num gpus is not supported for now");
  if (world_size != handles.size())
    throw std::invalid_argument(
        "handles length should equal to offsets length");
  if (rank < 0 || rank >= world_size)
    throw std::invalid_argument("invalid rank passed in");

  cudaIpcMemHandle_t ipc_handles[8];
  for (int i = 0; i < world_size; i++) {
    std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
  }
  return (fptr_t) new vllm::CustomAllreduce(
      reinterpret_cast<vllm::Metadata *>(meta.data_ptr()), rank_data.data_ptr(),
      rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}

/**
 * Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
 * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
 * because it allows transpose of contiguous slice (i.e. slicing the first
 * dimension). Currently, we require this because stride information is not
 * passed into the kernels and we treat input tensors as flat.
 *
 * Examples
 * A = torch.zeros(3, 3, 3)
 * 1. A: OK
 * 2. A[1:]: OK
 * 3. A.permute(2, 0, 1): OK
 * 4. A[1:].permute(2, 0, 1): OK
 * 5. A[None].expand(2, -1, -1, -1): Not OK
 * 6. A[:, 1:, 1:]: Not OK
 */
bool _is_weak_contiguous(torch::Tensor &t) {
  return t.is_contiguous() ||
         (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
          t.numel() * t.element_size());
}

bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
                      bool full_nvlink) {
  auto inp_size = inp.numel() * inp.element_size();
  // custom allreduce requires input byte size to be multiples of 16
  if (inp_size % 16 != 0) return false;
  if (!_is_weak_contiguous(inp)) return false;
  if (world_size == 2 || full_nvlink) return inp_size <= max_size;
  // 4 PCIE GPUs use 2 stage allreduce, and is only faster than NCCL when size
  // <= 512k
  return world_size <= 4 && inp_size <= 512 * 1024;
}

void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
                 cudaStream_t stream) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  TORCH_CHECK(_is_weak_contiguous(out));
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
                           reinterpret_cast<float *>(out.data_ptr()),
                           out.numel());
      break;
    }
    case at::ScalarType::Half: {
      fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
                          reinterpret_cast<half *>(out.data_ptr()),
                          out.numel());
      break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      fa->allreduce<nv_bfloat16>(
          stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
          reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
      break;
    }
#endif
    default:
      throw std::runtime_error(
          "custom allreduce only supports float32, float16 and bfloat16");
  }
}

void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  _all_reduce(_fa, inp, out, stream);
}

void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
                      torch::Tensor &out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  auto input_size = inp.numel() * inp.element_size();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
              "registered buffer is too small to contain the input");
  AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
                                input_size, cudaMemcpyDeviceToDevice, stream));
  _all_reduce(_fa, reg_buffer, out, stream);
}

void dispose(fptr_t _fa) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  delete fa;
}

int meta_size() { return sizeof(vllm::Metadata); }

void register_buffer(fptr_t _fa, torch::Tensor &t,
                     const std::vector<std::string> &handles,
                     const std::vector<int64_t> &offsets) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  fa->register_buffer(handles, offsets, t.data_ptr());
}

std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  return fa->get_graph_buffer_ipc_meta();
}

void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
                            const std::vector<std::vector<int64_t>> &offsets) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  fa->register_graph_buffers(handles, offsets);
}
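A small host-side check of the weak-contiguity rule documented in the _is_weak_contiguous comment above, mirroring its numbered examples. This is an editorial sketch: it assumes _is_weak_contiguous is visible to the caller and uses the C++ equivalents of the Python slicing in that comment.

#include <torch/torch.h>

void weak_contiguity_examples() {
  auto A = torch::zeros({3, 3, 3});
  TORCH_CHECK(_is_weak_contiguous(A));    // 1. A: OK
  auto a2 = A.slice(0, 1);                // A[1:]
  TORCH_CHECK(_is_weak_contiguous(a2));   // 2. OK
  auto a3 = A.permute({2, 0, 1});         // A.permute(2, 0, 1)
  TORCH_CHECK(_is_weak_contiguous(a3));   // 3. OK: whole storage still covered
  auto a6 = A.slice(1, 1).slice(2, 1);    // A[:, 1:, 1:]
  TORCH_CHECK(!_is_weak_contiguous(a6));  // 6. Not OK
}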
562  csrc/custom_all_reduce.cuh  Normal file
@@ -0,0 +1,562 @@
#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <iostream>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>

#define CUDACHECK(cmd) \
  do { \
    cudaError_t e = cmd; \
    if (e != cudaSuccess) { \
      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
             cudaGetErrorString(e)); \
      exit(EXIT_FAILURE); \
    } \
  } while (0)

namespace vllm {

struct Signal {
  alignas(64) union {
    uint64_t flag;
    unsigned char data[8];
  } start;
  alignas(64) union {
    uint64_t flag;
    unsigned char data[8];
  } end;
};

struct Metadata {
  alignas(128) Signal sg;
  alignas(128) int counter;
};
static_assert(offsetof(Metadata, counter) == 128);
static_assert(sizeof(Metadata) == 256);

struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };

struct RankSignals {
  volatile Signal *signals[8];
};

// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  T data[sz];
  using type = T;
  static constexpr int size = sz;
};

// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
  // the (P)acked type for load/store
  using P = array_t<T, 16 / sizeof(T)>;
  // the (A)ccumulator type for reduction
  using A = array_t<float, 16 / sizeof(T)>;
};

#define DINLINE __device__ __forceinline__

// scalar cast functions
DINLINE float upcast_s(half val) { return __half2float(val); }

template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
  return __float2half(val);
}

// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
DINLINE half &assign_add(half &a, half b) {
  a = __hadd(a, b);
  return a;
}
DINLINE float &assign_add(float &a, float b) { return a += b; }

#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  return __float2bfloat16(val);
}
DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) {
  a = __hadd(a, b);
  return a;
}
#endif

template <typename T, int N>
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
#pragma unroll
  for (int i = 0; i < N; i++) {
    assign_add(a.data[i], b.data[i]);
  }
  return a;
}

template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (std::is_same<T, float>::value) {
    return val;
  } else {
    array_t<float, N> out;
#pragma unroll
    for (int i = 0; i < N; i++) {
      out.data[i] = upcast_s(val.data[i]);
    }
    return out;
  }
}

template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  if constexpr (std::is_same<typename O::type, float>::value) {
    return val;
  } else {
    O out;
#pragma unroll
    for (int i = 0; i < O::size; i++) {
      out.data[i] = downcast_s<typename O::type>(val.data[i]);
    }
    return out;
  }
}

// compute flag at compile time
__host__ __device__ constexpr uint64_t compute_flag(int ngpus) {
  auto m = std::numeric_limits<uint64_t>::max();
  return m >> ((8 - ngpus) * 8);
}

template <int ngpus>
DINLINE void start_sync(const RankSignals &sg, volatile Metadata *meta,
                        int rank) {
  constexpr auto FLAG = compute_flag(ngpus);
  if (blockIdx.x == 0) {
    if (threadIdx.x < ngpus)
      // simultaneously write to the corresponding byte to all other ranks.
      // Latency = 1 p2p write
      sg.signals[threadIdx.x]->start.data[rank] = 255;
    else if (threadIdx.x == 32)
      // reset
      meta->sg.end.flag = 0;
  }
  if (threadIdx.x == 0) {
    while (meta->sg.start.flag != FLAG)
      ;
  }
  __syncthreads();
}

template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta,
                      int rank) {
  constexpr auto FLAG = compute_flag(ngpus);
  __syncthreads();
  __shared__ int num;
  if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1);
  __syncthreads();

  // Only the last completing block can perform the end synchronization
  // This can ensures when the final busy wait ends, all ranks must have
  // finished reading each other's buffer.
  if (num == gridDim.x - 1) {
    if (threadIdx.x == 32) {
      // reset in a different warp
      meta->counter = 0;
      meta->sg.start.flag = 0;
    } else if (threadIdx.x < ngpus) {
      // simultaneously write to the corresponding byte to all other ranks.
      // Latency = 1 p2p write
      sg.signals[threadIdx.x]->end.data[rank] = 255;
    }
    // if this is the final sync, only one block needs it
    // because kernel exit can serve as sync
    if constexpr (final_sync) {
      if (threadIdx.x == 0) {
        while (meta->sg.end.flag != FLAG)
          ;
      }
    }
  }
  if constexpr (!final_sync) {
    if (threadIdx.x == 0) {
      while (meta->sg.end.flag != FLAG)
        ;
    }
    __syncthreads();
  }
}

template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P *ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {
    packed_assign_add(tmp, upcast(ptrs[i][idx]));
  }
  return downcast<P>(tmp);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
                               volatile Metadata *meta, T *__restrict__ result,
                               int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the address so the accumulation order is the same
  // for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  start_sync<ngpus>(sg, meta, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P *)result)[idx] =
        packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
  }
  end_sync<ngpus, true>(sg, meta, rank);
}

template <typename P>
DINLINE P *get_tmp_buf(volatile Signal *sg) {
  return (P *)(((Metadata *)sg) + 1);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
                               volatile Metadata *meta, T *__restrict__ result,
                               int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  const P *ptrs[ngpus];
  P *tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P *)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  start_sync<ngpus>(sg, meta, rank);
  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  // Maybe TODO: replace this with per-block release-acquire
  // can save about 1-2us (not a lot though)
  end_sync<ngpus>(sg, meta, rank);

  // stage 2: allgather
  for (int idx = tid; idx < part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int dst_idx = ((rank + i) % ngpus) * part + idx;
      ((P *)result)[dst_idx] = tmps[i][idx];
    }
  }
  // process the last larger partition
  int remaining = size - part * ngpus;
  if (tid < remaining) {
    int dst_idx = tid + part * ngpus;
    ((P *)result)[dst_idx] = get_tmp_buf<P>(sg.signals[ngpus - 1])[part + tid];
  }

  // faster than this
  // for (int idx = tid; idx < size; idx += stride) {
  //   int target_rank = idx / part;
  //   if (target_rank == ngpus) target_rank -= 1;
  //   ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part];
  // }
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg,
                                       volatile Metadata *meta,
                                       T *__restrict__ result, int rank,
                                       int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  auto tmp_out = get_tmp_buf<P>(sg.signals[rank]);
  constexpr int hg = ngpus / 2;
  // Actually not quite half butterfly.
  // This is an all-to-all within each group containing half of the ranks
  // followed by cross-group add. Equivalent to half butterfly when there
  // are 4 GPUs, a common case for PCIe cards like T4 and A10.
  const P *ptrs[hg];
  {
    int start = rank - rank % hg;
#pragma unroll
    for (int i = 0; i < hg; i++) {
      ptrs[i] = (const P *)_dp->ptrs[i + start];
    }
  }
  start_sync<ngpus>(sg, meta, rank);
  for (int idx = tid; idx < size; idx += stride) {
    tmp_out[idx] = packed_reduce<P, hg, A>(ptrs, idx);
  }
  end_sync<ngpus>(sg, meta, rank);

  auto src = get_tmp_buf<P>(sg.signals[(ngpus - 1) - rank % ngpus]);
  // do the cross group reduction
  for (int idx = tid; idx < size; idx += stride) {
    auto tmp = tmp_out[idx];
    packed_assign_add(tmp, src[idx]);
    ((P *)result)[idx] = tmp;
  }
}

using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));

class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  bool full_nvlink_;

  // below are device pointers
  RankSignals sg_;
  std::unordered_map<void *, RankData *> buffers_;
  Metadata *meta_;

  // stores the registered device pointers from all ranks
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void *> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char *> ipc_handles_;

  /**
   * meta is a pointer to device metadata and temporary buffer for allreduce.
   *
   * There's a total of sizeof(Metadata) of prefix before the actual data,
   * so meta + 1 points to actual temporary buffer.
   *
   * note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor
   */
  CustomAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz,
                  const cudaIpcMemHandle_t *handles,
                  const std::vector<int64_t> &offsets, int rank,
                  bool full_nvlink = true)
      : rank_(rank),
        world_size_(offsets.size()),
        full_nvlink_(full_nvlink),
        meta_(meta),
        d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
      Metadata *rank_meta;
      if (i != rank_) {
        char *handle = open_ipc_handle(&handles[i]);
        handle += offsets[i];
        rank_meta = (Metadata *)handle;
      } else {
        rank_meta = meta_;
      }
      sg_.signals[i] = &rank_meta->sg;
    }
  }

  char *open_ipc_handle(const void *ipc_handle) {
    auto [it, new_handle] =
        ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
    if (new_handle) {
      char *ipc_ptr;
      CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
                                     *((const cudaIpcMemHandle_t *)ipc_handle),
                                     cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }

  std::pair<std::vector<uint8_t>, std::vector<int64_t>>
  get_graph_buffer_ipc_meta() {
    auto num_buffers = graph_unreg_buffers_.size();
    auto handle_sz = sizeof(cudaIpcMemHandle_t);
    std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
    std::vector<int64_t> offsets(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void *base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (cuPointerGetAttribute(&base_ptr,
                                CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
                                (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(cudaIpcGetMemHandle(
          (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char *)ptr) - ((char *)base_ptr);
    }
    return std::make_pair(handles, offsets);
  }

  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " +
          std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  void register_buffer(const std::vector<std::string> &handles,
                       const std::vector<int64_t> &offsets, void *self) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      if (i != rank_) {
        char *handle = open_ipc_handle(handles[i].data());
        handle += offsets[i];
        data.ptrs[i] = handle;
      } else {
        data.ptrs[i] = self;
      }
    }
    auto d_data = d_rank_data_base_++;
    CUDACHECK(
        cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[self] = d_data;
  }

  // note: when registering graph buffers, we intentionally choose to not
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the remote
  // possibility of different allocation patterns between ranks. For example,
  // rank 1 may get the same input address for the second allreduce, but rank 2
  // got a different address. IPC handles have internal reference counting
  // mechanism so overhead should be small.
  void register_graph_buffers(
      const std::vector<std::string> &handles,
      const std::vector<std::vector<int64_t>> &offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto &rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char *handle =
              open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
                         sizeof(RankData) * num_buffers,
                         cudaMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }

  /**
   * This is the result after careful grid search. Using 36 blocks give the best
   * or close to the best runtime on the devices I tried: A100, A10, A30, T4,
   * V100. You'll notice that NCCL kernels also only take a small amount of SMs.
   * Not quite sure the underlying reason, but my guess is that too many SMs
   * will cause contention on NVLink bus.
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T *input, T *output, int size,
                 int threads = 512, int block_limit = 36) {
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));

    RankData *ptrs;
    cudaStreamCaptureStatus status;
    CUDACHECK(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " +
            std::to_string(reinterpret_cast<uint64_t>(input)) +
            " is not registered!");
      ptrs = it->second;
    }

    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
  name<T, ngpus> \
      <<<blocks, threads, 0, stream>>>(ptrs, sg_, meta_, output, rank_, size);
#define REDUCE_CASE(ngpus) \
  case ngpus: { \
    if (world_size_ == 2) { \
      KL(ngpus, cross_device_reduce_1stage); \
    } else if (full_nvlink_) { \
      if ((world_size_ <= 4 && bytes < 512 * 1024) || \
          (world_size_ <= 8 && bytes < 256 * 1024)) { \
        KL(ngpus, cross_device_reduce_1stage); \
      } else { \
        KL(ngpus, cross_device_reduce_2stage); \
      } \
    } else { \
      KL(ngpus, cross_device_reduce_half_butterfly); \
    } \
    break; \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
  }

  ~CustomAllreduce() {
    for (auto [_, ptr] : ipc_handles_) {
      CUDACHECK(cudaIpcCloseMemHandle(ptr));
    }
  }
};
/**
 * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
 a template instantiation:
 * template void CustomAllreduce::allreduce<half>(cudaStream_t, half *, half *,
 int, int, int);
 */
} // namespace vllm
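As an editorial sanity check (not part of the header): each rank writes 0xff into its own byte of the start/end word, so the value the busy-waits compare against is exactly compute_flag(ngpus). Because compute_flag is constexpr, this can be verified at compile time:

static_assert(vllm::compute_flag(2) == 0xffffULL);
static_assert(vllm::compute_flag(4) == 0xffffffffULL);
static_assert(vllm::compute_flag(8) == 0xffffffffffffffffULL);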
284
csrc/custom_all_reduce_test.cu
Normal file
284
csrc/custom_all_reduce_test.cu
Normal file
@@ -0,0 +1,284 @@
/**
 * This is a standalone test for custom allreduce.
 * To compile, make sure you have MPI and NCCL installed in your system.
 * export MPI_HOME=XXX
 * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
 * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
 *
 * Warning: this C++ test is not designed to be very readable and was used
 * during the rapid prototyping process.
 *
 * To run:
 * mpirun -np 8 ./custom_all_reduce_test
 */
#include <cuda.h>
#include <curand_kernel.h>
#include <stdio.h>
#include <stdlib.h>

#include <limits>
#include <vector>

#include "cuda_profiler_api.h"
#include "custom_all_reduce.cuh"
#include "mpi.h"
#include "nccl.h"

#define MPICHECK(cmd)                                                  \
  do {                                                                 \
    int e = cmd;                                                       \
    if (e != MPI_SUCCESS) {                                            \
      printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
      exit(EXIT_FAILURE);                                              \
    }                                                                  \
  } while (0)

#define NCCLCHECK(cmd)                                              \
  do {                                                              \
    ncclResult_t r = cmd;                                           \
    if (r != ncclSuccess) {                                         \
      printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
             ncclGetErrorString(r));                                \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)

__global__ void dummy_kernel() {
  for (int i = 0; i < 100; i++) __nanosleep(1000000);  // 100ms
}

template <typename T>
__global__ void set_data(T *data, int size, int myRank) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    data[idx] = myRank * 0.11f;
  }
}

template <typename T>
__global__ void convert_data(const T *data1, const T *data2, double *fdata1,
                             double *fdata2, int size) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    fdata1[idx] = data1[idx];
    fdata2[idx] = data2[idx];
  }
}

__global__ void init_rand(curandState_t *state, int size, int nRanks) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    for (int i = 0; i < nRanks; i++) {
      curand_init(i + 1, idx, 0, &state[idx * nRanks + i]);
    }
  }
}

template <typename T>
__global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
                         int myRank, int nRanks, int size) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    double sum = 0.0;
    for (int i = 0; i < nRanks; i++) {
      double val = curand_uniform_double(&state[idx * nRanks + i]) * 4;
      T hval = val;  // downcast first
      sum += static_cast<double>(hval);
      if (i == myRank) data[idx] = hval;
    }
    ground_truth[idx] = sum;
  }
}

template <typename T>
void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
         int data_size) {
  T *result;
  cudaStream_t stream;
  CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
  CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
  CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T)));

  cudaIpcMemHandle_t self_data_handle;
  cudaIpcMemHandle_t data_handles[8];
  vllm::Metadata *buffer;
  T *self_data_copy;
  /**
   * Allocate IPC buffer
   *
   * The first section is a temporary buffer for storing intermediate allreduce
   * results, if a particular algorithm requires it. The second section is for
   * the input to the allreduce. The actual API takes the input pointer as an
   * argument (that is, they can and usually should be allocated separately).
   * But since the input pointers and the temporary buffer all require IPC
   * registration, they are allocated and registered together in the test for
   * convenience.
   */
  CUDACHECK(
      cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
  CUDACHECK(cudaMemset(buffer, 0,
                       2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
  CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
  CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer));

  MPICHECK(MPI_Allgather(&self_data_handle, sizeof(cudaIpcMemHandle_t),
                         MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
                         MPI_BYTE, MPI_COMM_WORLD));

  void *rank_data;
  size_t rank_data_sz = 16 * 1024 * 1024;
  CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
  std::vector<int64_t> offsets(nRanks, 0);
  vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
                           offsets, myRank);
  auto *self_data =
      reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) +
                            sizeof(vllm::Metadata) + data_size * sizeof(T));
  // hack buffer registration
  {
    std::vector<std::string> handles;
    handles.reserve(nRanks);
    for (int i = 0; i < nRanks; i++) {
      char *begin = (char *)&data_handles[i];
      char *end = (char *)&data_handles[i + 1];
      handles.emplace_back(begin, end);
    }
    std::vector<int64_t> offsets(
        nRanks, sizeof(vllm::Metadata) + data_size * sizeof(T));
    fa.register_buffer(handles, offsets, self_data);
  }

  double *ground_truth;
  CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
  curandState_t *states;
  CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
  init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
  gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
                                        nRanks, data_size);
  CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T),
                            cudaMemcpyDeviceToDevice, stream));
  cudaEvent_t start, stop;
  CUDACHECK(cudaEventCreate(&start));
  CUDACHECK(cudaEventCreate(&stop));

  ncclDataType_t ncclDtype;
  if (std::is_same<T, half>::value) {
    ncclDtype = ncclFloat16;
  } else if (std::is_same<T, nv_bfloat16>::value) {
    ncclDtype = ncclBfloat16;
  } else {
    ncclDtype = ncclFloat;
  }

  dummy_kernel<<<1, 1, 0, stream>>>();
  constexpr int warmup_iters = 5;
  constexpr int num_iters = 25;
  // warmup
  for (int i = 0; i < warmup_iters; i++) {
    NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
                            stream));
  }
  CUDACHECK(cudaEventRecord(start, stream));
  for (int i = 0; i < num_iters; i++) {
    NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
                            stream));
  }
  CUDACHECK(cudaEventRecord(stop, stream));
  CUDACHECK(cudaStreamSynchronize(stream));
  float allreduce_ms = 0;
  cudaEventElapsedTime(&allreduce_ms, start, stop);

  // if (myRank == 1) dummy_kernel<<<1, 1, 0, stream>>>();
  // set_data<T><<<16, 1024, 0, stream>>>(self_data, data_size, myRank);

  dummy_kernel<<<1, 1, 0, stream>>>();
  // warm up
  for (int i = 0; i < warmup_iters; i++) {
    fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
  }
  CUDACHECK(cudaEventRecord(start, stream));
  for (int i = 0; i < num_iters; i++) {
    fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
  }
  CUDACHECK(cudaEventRecord(stop, stream));
  CUDACHECK(cudaStreamSynchronize(stream));

  float duration_ms = 0;
  cudaEventElapsedTime(&duration_ms, start, stop);
  if (myRank == 0)
    printf(
        "Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl "
        "time:%.2fus\n",
        myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit,
        duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);

  // And wait for all the queued up work to complete
  CUDACHECK(cudaStreamSynchronize(stream));

  NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
                          ncclSum, comm, stream));

  double *nccl_result, *my_result;
  CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double)));
  CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double)));

  convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result,
                                            my_result, data_size);
  CUDACHECK(cudaStreamSynchronize(stream));

  for (unsigned long j = 0; j < data_size; j++) {
    auto diff = abs(nccl_result[j] - my_result[j]);
    if (diff >= 1e-2) {
      printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n",
             myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
      break;
    }
  }

  long double nccl_diffs = 0.0;
  long double my_diffs = 0.0;
  for (int j = 0; j < data_size; j++) {
    nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
    my_diffs += abs(my_result[j] - ground_truth[j]);
  }
  if (myRank == 0)
    std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
              << " me: " << my_diffs / data_size << std::endl;

  CUDACHECK(cudaFree(result));
  CUDACHECK(cudaFree(self_data_copy));
  CUDACHECK(cudaFree(rank_data));
  CUDACHECK(cudaFree(buffer));
  CUDACHECK(cudaFree(states));
  CUDACHECK(cudaFreeHost(ground_truth));
  CUDACHECK(cudaFreeHost(nccl_result));
  CUDACHECK(cudaFreeHost(my_result));
  CUDACHECK(cudaStreamDestroy(stream));
}

int main(int argc, char **argv) {
  int nRanks, myRank;
  MPICHECK(MPI_Init(&argc, &argv));
  MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
  MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));
  CUDACHECK(cudaSetDevice(myRank));
  ncclUniqueId id;
  ncclComm_t comm;
  if (myRank == 0) ncclGetUniqueId(&id);
  MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0,
                     MPI_COMM_WORLD));
  NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));

  cudaProfilerStart();
  // for (int threads : {256, 512}) {
  //   for (int block_limit = 16; block_limit < 112; block_limit += 4) {
  //     run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
  //   }
  // }
  for (int sz = 512; sz <= (32 << 20); sz *= 2) {
    run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 50);
  }

  cudaProfilerStop();
  return EXIT_SUCCESS;
}
@@ -14,3 +14,24 @@
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)             \
   AT_DISPATCH_SWITCH(                                             \
     TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)      \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)        \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)    \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...)           \
+  AT_DISPATCH_SWITCH(                                                    \
+    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)         \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
+
+#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...)              \
+  AT_DISPATCH_SWITCH(                                              \
+    TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
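For readers less familiar with PyTorch's dispatch machinery: the new VLLM_DISPATCH_INTEGRAL_TYPES macro above simply wraps AT_DISPATCH_SWITCH so a kernel launch can be written once and specialized per integral dtype. A minimal, purely illustrative sketch follows; the kernel and tensor names here are made up, and the real in-tree user is the moe_align_block_size launcher in the next file.

// --- illustrative sketch, not part of the diff ---
// Inside the lambda, `scalar_t` is bound to the matching C++ type
// (uint8_t / int8_t / int16_t / int32_t / int64_t).
VLLM_DISPATCH_INTEGRAL_TYPES(
    input.scalar_type(), "my_kernel", [&] {
      my_kernel<scalar_t><<<grid, block, 0, stream>>>(
          input.data_ptr<scalar_t>(), input.numel());
    });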
108
csrc/moe_align_block_size_kernels.cu
Normal file
@@ -0,0 +1,108 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>

#include "cuda_compat.h"
#include "dispatch_utils.h"

const static size_t NUM_MAX_EXPERTS = 64;
#define CEILDIV(x,y) (((x) + (y) - 1) / (y))

namespace vllm {
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
                                            int32_t *sorted_token_ids,
                                            int32_t *expert_ids,
                                            int32_t *total_tokens_post_pad,
                                            int32_t num_experts,
                                            int32_t block_size,
                                            size_t numel) {
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;
  __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
  __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[threadIdx.x + 1][i] = 0;
  }

  /**
   * In the first step we compute token_cnts[thread_index + 1][expert_index],
   * which counts how many tokens in the token shard of thread_index are assigned
   * to expert expert_index.
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    ++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  tokens_cnts[0][threadIdx.x] = 0;
  for (int i = 1; i <= blockDim.x; ++i) {
    tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x];
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[blockDim.x][i - 1], block_size) * block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding blocks
   * and stores the corresponding expert_id for each block.
   */
  for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
    expert_ids[i / block_size] = threadIdx.x;
  }

  /**
   * Each thread processes a token shard, calculating the index of each token after
   * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
   * block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
   * where * represents a padding value (preset in python).
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
     * stores the indices of the tokens processed by the expert with expert_id within
     * the current thread's token shard.
     */
    int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = i;
    ++tokens_cnts[threadIdx.x][expert_id];
  }
}
}

void moe_align_block_size(
  torch::Tensor topk_ids,
  int num_experts,
  int block_size,
  torch::Tensor sorted_token_ids,
  torch::Tensor experts_ids,
  torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  assert(num_experts <= NUM_MAX_EXPERTS);
  VLLM_DISPATCH_INTEGRAL_TYPES(
    topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
      vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
        topk_ids.data_ptr<scalar_t>(),
        sorted_token_ids.data_ptr<int32_t>(),
        experts_ids.data_ptr<int32_t>(),
        num_tokens_post_pad.data_ptr<int32_t>(),
        num_experts,
        block_size,
        topk_ids.numel());
    });
}
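To make the padding scheme described in the kernel comments concrete, here is a minimal single-threaded C++ sketch of the same bookkeeping (not the CUDA kernel itself; the -1 padding value is an assumption of this sketch, whereas the real padding value is preset from the Python side). For topk_ids = [0,1,2,1,2,3,0,3,4] and block_size = 4 it reproduces the layout [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *] from the comment above.

// --- illustrative reference sketch, not part of the diff ---
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> topk_ids = {0, 1, 2, 1, 2, 3, 0, 3, 4};
  const int num_experts = 5, block_size = 4;

  // Count tokens per expert.
  std::vector<int> cnt(num_experts, 0);
  for (int e : topk_ids) ++cnt[e];

  // Cumulative sum of per-expert counts, each rounded up to block_size.
  std::vector<int> cumsum(num_experts + 1, 0);
  for (int e = 0; e < num_experts; ++e)
    cumsum[e + 1] = cumsum[e] + (cnt[e] + block_size - 1) / block_size * block_size;

  // Each block of block_size slots belongs to exactly one expert.
  std::vector<int> sorted_token_ids(cumsum[num_experts], -1);  // -1 = padding
  std::vector<int> expert_ids(cumsum[num_experts] / block_size);
  for (int e = 0; e < num_experts; ++e)
    for (int b = cumsum[e]; b < cumsum[e + 1]; b += block_size)
      expert_ids[b / block_size] = e;

  // Scatter token indices into their expert's region, in order of appearance.
  std::vector<int> fill(num_experts, 0);
  for (int i = 0; i < (int)topk_ids.size(); ++i) {
    int e = topk_ids[i];
    sorted_token_ids[cumsum[e] + fill[e]++] = i;
  }

  for (int v : sorted_token_ids) printf("%d ", v);  // 0 6 -1 -1 1 3 -1 -1 2 4 -1 -1 5 7 -1 -1 8 -1 -1 -1
  printf("\n");
  return 0;
}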
43
csrc/ops.h
@@ -13,7 +13,8 @@ void paged_attention_v1(
   torch::Tensor& context_lens,
   int block_size,
   int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes);
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const std::string& kv_cache_dtype);
 
 void paged_attention_v2(
   torch::Tensor& out,
@@ -29,7 +30,8 @@ void paged_attention_v2(
   torch::Tensor& context_lens,
   int block_size,
   int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes);
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const std::string& kv_cache_dtype);
 
 void rms_norm(
   torch::Tensor& out,
@@ -70,6 +72,14 @@ torch::Tensor awq_gemm(
   torch::Tensor _scaling_factors,
   torch::Tensor _zeros,
   int split_k_iters);
+
+torch::Tensor awq_dequantize(
+  torch::Tensor _kernel,
+  torch::Tensor _scaling_factors,
+  torch::Tensor _zeros,
+  int split_k_iters,
+  int thx,
+  int thy);
 #endif
 
 void squeezellm_gemm(
@@ -89,3 +99,32 @@ torch::Tensor gptq_gemm(
 void gptq_shuffle(
   torch::Tensor q_weight,
   torch::Tensor q_perm);
+
+void moe_align_block_size(
+  torch::Tensor topk_ids,
+  int num_experts,
+  int block_size,
+  torch::Tensor sorted_token_ids,
+  torch::Tensor experts_ids,
+  torch::Tensor num_tokens_post_pad);
+
+#ifndef USE_ROCM
+using fptr_t = uint64_t;
+fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
+                      const std::vector<std::string> &handles,
+                      const std::vector<int64_t> &offsets, int rank,
+                      bool full_nvlink);
+bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
+                      bool full_nvlink);
+void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out);
+void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
+                      torch::Tensor &out);
+void dispose(fptr_t _fa);
+int meta_size();
+void register_buffer(fptr_t _fa, torch::Tensor &t,
+                     const std::vector<std::string> &handles,
+                     const std::vector<int64_t> &offsets);
+std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
+void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
+                            const std::vector<std::vector<int64_t>> &offsets);
+#endif
217
csrc/punica/LICENSE
Normal file
@@ -0,0 +1,217 @@
Contains code from https://github.com/punica-ai/punica

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright {yyyy} {name of copyright owner}

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

------------------------------------------------------------------------------------

This product bundles various third-party components under other open source licenses.
This section summarizes those components and their licenses. See licenses/
for text of these licenses.


Apache-2.0
* third_party/nvbench (with LLVM exception)
* third_party/flashinfer

BSD-3-Clause:
* third_party/cutlass
4
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
4
csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu
Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)
4
csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu
Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)
4
csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu
Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)
4
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
4
csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu
Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
59
csrc/punica/bgmv/bgmv_config.h
Normal file
@@ -0,0 +1,59 @@
#pragma once

template <int feat_in, int feat_out, typename in_T, typename out_T,
          typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                 const W_T *__restrict__ W,
                 const int64_t *__restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                 int64_t layer_idx, float scale);

// clang-format off

#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
    f(in_T, out_T, W_T, narrow, 128) \
    f(in_T, out_T, W_T, narrow, 256) \
    f(in_T, out_T, W_T, narrow, 512) \
    f(in_T, out_T, W_T, narrow, 1024) \
    f(in_T, out_T, W_T, narrow, 1280) \
    f(in_T, out_T, W_T, narrow, 1728) \
    f(in_T, out_T, W_T, narrow, 1792) \
    f(in_T, out_T, W_T, narrow, 2048) \
    f(in_T, out_T, W_T, narrow, 2560) \
    f(in_T, out_T, W_T, narrow, 2752) \
    f(in_T, out_T, W_T, narrow, 3072) \
    f(in_T, out_T, W_T, narrow, 3456) \
    f(in_T, out_T, W_T, narrow, 3584) \
    f(in_T, out_T, W_T, narrow, 4096) \
    f(in_T, out_T, W_T, narrow, 5120) \
    f(in_T, out_T, W_T, narrow, 5504) \
    f(in_T, out_T, W_T, narrow, 5632) \
    f(in_T, out_T, W_T, narrow, 6912) \
    f(in_T, out_T, W_T, narrow, 7168) \
    f(in_T, out_T, W_T, narrow, 8192) \
    f(in_T, out_T, W_T, narrow, 9216) \
    f(in_T, out_T, W_T, narrow, 10240) \
    f(in_T, out_T, W_T, narrow, 11008) \
    f(in_T, out_T, W_T, narrow, 12288) \
    f(in_T, out_T, W_T, narrow, 13824) \
    f(in_T, out_T, W_T, narrow, 14336) \
    f(in_T, out_T, W_T, narrow, 16384) \
    f(in_T, out_T, W_T, narrow, 20480) \
    f(in_T, out_T, W_T, narrow, 28672) \
    f(in_T, out_T, W_T, narrow, 32000) \
    f(in_T, out_T, W_T, narrow, 32256) \
    f(in_T, out_T, W_T, narrow, 32512) \
    f(in_T, out_T, W_T, narrow, 32768) \
    f(in_T, out_T, W_T, narrow, 33024) \
    f(in_T, out_T, W_T, narrow, 36864) \
    f(in_T, out_T, W_T, narrow, 49152) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA

// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8)         \
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16)        \
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32)        \
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)

// clang-format on
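To make the instantiation pattern concrete: each bgmv_*.cu stub expands FOR_BGMV_WIDE_NARROW into explicit template instantiations of bgmv_kernel, one per (narrow, wide) pair and one per direction. Roughly, and manually unrolled for illustration only, with in_T = out_T = W_T = nv_half (as in bgmv_fp16_fp16_fp16.cu) the single pair narrow = 16, wide = 4096 contributes:

// --- illustrative expansion sketch, not part of the diff ---
// INST_BGMV_TWOSIDE(nv_half, nv_half, nv_half, 16, 4096), once INST_BGMV
// (defined at the end of bgmv_impl.cuh) is expanded, boils down to:
template void bgmv_kernel<16, 4096>(
    nv_half *__restrict__ Y, const nv_half *__restrict__ X,
    const nv_half *__restrict__ W, const int64_t *__restrict__ indicies,
    int64_t y_offset, int64_t full_y_size, int64_t batch_size,
    int64_t num_layers, int64_t layer_idx, float scale);  // LoRA "expand" direction
template void bgmv_kernel<4096, 16>(
    nv_half *__restrict__ Y, const nv_half *__restrict__ X,
    const nv_half *__restrict__ W, const int64_t *__restrict__ indicies,
    int64_t y_offset, int64_t full_y_size, int64_t batch_size,
    int64_t num_layers, int64_t layer_idx, float scale);  // LoRA "shrink" direction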
4
csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
4
csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
4
csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
4
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
4
csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu
Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
4
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
4
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
4
csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu
Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
4
csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu
Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
4
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
4
csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu
Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
4
csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu
Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
294
csrc/punica/bgmv/bgmv_impl.cuh
Normal file
@@ -0,0 +1,294 @@
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cuda/pipeline>
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>

#include "vec_dtypes.cuh"

namespace cg = cooperative_groups;

// nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
          size_t W_copy_size, int tx, int ty, int tz, typename in_T,
          typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t j = blockIdx.x;
  constexpr size_t num_pipeline_stages = 2;
  constexpr size_t tile_size = tx * ty * vec_size;
  __shared__ W_T W_shared[num_pipeline_stages * tile_size];
  __shared__ in_T X_shared[num_pipeline_stages * tile_size];
  __shared__ float y_warpwise[ty];

  size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  auto pipe = cuda::make_pipeline();

  // pipeline load W/X and compute WX;
  pipe.producer_acquire();
  cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     W + (idx * feat_out + j) * feat_in +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
  cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     X + (batch_idx * feat_in) +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
  pipe.producer_commit();
  size_t copy_idx, compute_idx;
  float y = 0.f;
  vec_t<in_T, vec_size> x_vec;
  vec_t<W_T, vec_size> w_vec;
  size_t tile_idx;

#pragma unroll
  for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
       ++tile_idx) {
    copy_idx = tile_idx % num_pipeline_stages;
    // pipeline stage: async copy W fragment
    pipe.producer_acquire();
    if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
      cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         W + (idx * feat_out + j) * feat_in +
                             tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
      cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         X + (batch_idx * feat_in) + tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
    }
    pipe.producer_commit();

    compute_idx = (tile_idx - 1) % num_pipeline_stages;
    // pipeline stage: compute WX
    pipe.consumer_wait();
    block.sync();
    x_vec.load(X_shared + X_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    w_vec.load(W_shared + W_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    float sum = 0.f;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      sum += float(w_vec[i]) * float(x_vec[i]) * scale;
    }
#pragma unroll
    for (size_t offset = tx / 2; offset > 0; offset /= 2) {
      sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    y_warpwise[threadIdx.y] = sum;
    block.sync();
#pragma unroll
    for (size_t i = 0; i < ty; ++i) {
      y += y_warpwise[i];
    }

    block.sync();
    pipe.consumer_release();
  }

  compute_idx = (tile_idx - 1) % num_pipeline_stages;
  // final pipeline stage
  pipe.consumer_wait();
  block.sync();
  x_vec.load(X_shared + X_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  w_vec.load(W_shared + W_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
  }
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  y_warpwise[threadIdx.y] =
      ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
          ? sum
          : 0.f;
  block.sync();
#pragma unroll
  for (size_t i = 0; i < ty; ++i) {
    y += y_warpwise[i];
  }

  block.sync();
  pipe.consumer_release();

  // write Y;
  if (block.thread_rank() == 0) {
    Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
  }
}

// nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
          typename in_T, typename out_T, typename W_T>
__global__ void
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;

  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t tile_idx = blockIdx.x;

  // load X;
  vec_t<in_T, vec_size> x_vec;
  x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);

  // load W;
  vec_t<W_T, vec_size> w_vec;
  w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
             block.thread_rank() * vec_size);

  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
  }

  cg::thread_block_tile g = cg::tiled_partition<tx>(block);
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += g.shfl_down(sum, offset);
  }
  sum = g.shfl(sum, 0);

  if (threadIdx.x == 0) {
    Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
      threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
  }
}

template <int feat_in, int feat_out, typename in_T, typename out_T,
          typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                 const W_T *__restrict__ W,
                 const int64_t *__restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                 int64_t layer_idx, float scale) {
  constexpr size_t vec_size = 8;
  constexpr int tz = 4;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if constexpr (feat_in < feat_out) {
    static_assert(feat_in % vec_size == 0);
    constexpr int tx = feat_in / vec_size;

    static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
                  (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
                  (8 % tx == 0 && feat_out % (8 / tx * tz) == 0));

    if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
      constexpr int ty = 32 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
      constexpr int ty = 16 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else {
      constexpr int ty = 8 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
  } else {
    static_assert(feat_in % (vec_size * 32) == 0 ||
                  feat_in % (vec_size * 16) == 0 ||
                  feat_in % (vec_size * 8) == 0);

    if constexpr (feat_in % (vec_size * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
                         vec_size * sizeof(W_T), tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
      constexpr int tx = 16;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
  }
}

#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)                 \
  template void bgmv_kernel<feat_in, feat_out>(                        \
      out_T * __restrict__ Y, const in_T *__restrict__ X,              \
      const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
      int64_t y_offset, int64_t full_y_size, int64_t batch_size,       \
      int64_t num_layers, int64_t layer_idx, float scale);

#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
  INST_BGMV(narrow, wide, in_T, out_T, W_T)               \
  INST_BGMV(wide, narrow, in_T, out_T, W_T)
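A worked example may help with the launch-geometry arithmetic in bgmv_kernel above. This is an illustrative sketch only, under the assumption of the LoRA "expand" case feat_in = 16, feat_out = 4096 with the file's vec_size = 8 and tz = 4:

// --- illustrative geometry sketch, not part of the diff ---
constexpr int vec_size = 8, tz = 4;
constexpr int feat_in = 16, feat_out = 4096;
constexpr int tx = feat_in / vec_size;     // 2 lanes together cover one 16-wide input row
constexpr int ty = 32 / tx;                // 16, since 32 % tx == 0
static_assert(feat_out % (ty * tz) == 0);  // 4096 % 64 == 0, so the first branch is taken
// Launch shape: grid = (feat_out / (ty * tz), batch_size) = (64, B),
// block = (tx, ty, tz) = (2, 16, 4); each block accumulates ty * tz = 64 outputs,
// written by the threadIdx.x == 0 lane of each (y, z) pair.

The shrink direction (feat_in >= feat_out) instead fixes the block at (32, 4) or (16, 4) threads and pipelines W/X tiles through shared memory with cuda::memcpy_async, which is why it needs the per-tile producer/consumer stages above rather than a single vector load.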
27
csrc/punica/bgmv/generator.py
Normal file
@@ -0,0 +1,27 @@
DTYPES = ["fp16", "bf16", "fp32"]
DTYPE_MAP = {
    "fp16": "nv_half",
    "bf16": "nv_bfloat16",
    "fp32": "float",
}

TEMPLATE = """
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip()

for input_dtype in DTYPES:
    for output_dtype in DTYPES:
        for weight_dtype in DTYPES:
            if weight_dtype == "fp32":
                # FP32 weights are not supported.
                continue
            kernel_definition = TEMPLATE.format(
                input_dtype=DTYPE_MAP[input_dtype],
                output_dtype=DTYPE_MAP[output_dtype],
                weight_dtype=DTYPE_MAP[weight_dtype])
            filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
            with open(filename, "w") as f:
                f.write(kernel_definition)
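Presumably this generator is meant to be run from csrc/punica/bgmv itself, since it writes into the current directory. With three input dtypes, three output dtypes, and fp32 weights skipped, it emits exactly the 18 bgmv_<in>_<out>_<w>.cu stubs listed above, so supporting a new hidden size only requires editing bgmv_config.h and rerunning the script.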
1324
csrc/punica/bgmv/vec_dtypes.cuh
Normal file
File diff suppressed because it is too large
563
csrc/punica/punica_ops.cc
Normal file
563
csrc/punica/punica_ops.cc
Normal file
@@ -0,0 +1,563 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>

#include <cstdint>

#include "bgmv/bgmv_config.h"

namespace {

//====== utils ======

inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
                        const char *a_name, const char *b_name) {
  TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
              a.dim(), " vs ", b.dim());
  for (int i = 0; i < a.dim(); ++i) {
    TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
                ".size(", i, ")");
  }
}

inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
  return (uint32_t(a) << 16) | uint32_t(b);
}

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")

#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

#define CHECK_DIM(d, x) \
  TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")

#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)

#define CHECK_EQ(a, b) \
  TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)

//====== bgmv ======

template <typename in_T, typename out_T, typename W_T>
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
                               const int64_t *lora_indices,
                               uint16_t in_features, uint16_t out_features,
                               int64_t y_offset, int64_t full_y_size,
                               int64_t batch_size, int64_t num_layers,
                               int64_t layer_idx, float scale) {
  switch (pack_u16(in_features, out_features)) {
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out)           \
  case pack_u16(feat_in, feat_out):                                    \
    bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset,    \
                                   full_y_size, batch_size, num_layers, \
                                   layer_idx, scale);                  \
    break;
#define CASE(_in_T, _out_T, _W_T, narrow, wide) \
  CASE_ONESIDE(in_T, out_T, W_T, narrow, wide)  \
  CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)

    FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
#undef CASE
#undef CASE_ONESIDE
    default:
      return false;
  }

  return true;
}

void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                   torch::Tensor indicies, int64_t layer_idx, float scale) {
  CHECK_INPUT(y);
  CHECK_INPUT(x);
  CHECK_INPUT(w);
  CHECK_INPUT(indicies);

  CHECK_DIM(2, y);
  CHECK_DIM(2, x);
  CHECK_DIM(4, w);
  CHECK_DIM(1, indicies);

  int64_t B = x.size(0);
  int64_t h_in = x.size(1);
  int64_t h_out = y.size(1);
  int64_t num_layers = w.size(1);
  CHECK_EQ(w.size(3), h_in);
  CHECK_EQ(w.size(2), h_out);
  CHECK_EQ(indicies.size(0), x.size(0));
  CHECK_EQ(y.size(0), x.size(0));
  bool ok = false;
  if (h_in < 65536 && h_out < 65536) {
    // TODO: See if we can get rid of this massive nested switch
    switch (x.scalar_type()) {
      case at::ScalarType::Half:
        switch (y.scalar_type()) {
          case at::ScalarType::Half:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()), static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()), static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::BFloat16:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()), static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()), static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::Float:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()), static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()), static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          default:
            break;
        }
        break;
      case at::ScalarType::BFloat16:
        switch (y.scalar_type()) {
          case at::ScalarType::Half:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()), static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()), static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::BFloat16:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()), static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()), static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::Float:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()), static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()), static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          default:
            break;
        }
        break;
      case at::ScalarType::Float:
        switch (y.scalar_type()) {
          case at::ScalarType::Half:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()), static_cast<float *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()), static_cast<float *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::BFloat16:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()), static_cast<float *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()), static_cast<float *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::Float:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()), static_cast<float *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()), static_cast<float *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          default:
            break;
        }
        break;
      default:
        break;
    }
  }
  TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
              " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}

void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                             torch::Tensor indicies, int64_t layer_idx,
                             float scale, int64_t h_in, int64_t h_out,
                             int64_t y_offset) {
  CHECK_INPUT(y);
  CHECK_INPUT(x);
  CHECK_INPUT(w);
  CHECK_INPUT(indicies);

  CHECK_DIM(2, y);
  CHECK_DIM(2, x);
  CHECK_DIM(4, w);
  CHECK_DIM(1, indicies);

  int64_t B = x.size(0);
  int64_t num_layers = w.size(1);
  int64_t full_y_size = y.size(1);
  CHECK_EQ(w.size(3), h_in);
  CHECK_EQ(w.size(2), h_out);
  CHECK_EQ(indicies.size(0), x.size(0));
  CHECK_EQ(y.size(0), x.size(0));
  bool ok = false;
  if (h_in < 65536 && h_out < 65536) {
    // TODO: See if we can get rid of this massive nested switch
    switch (x.scalar_type()) {
      case at::ScalarType::Half:
        switch (y.scalar_type()) {
          case at::ScalarType::Half:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()), static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()), static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::BFloat16:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()), static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()), static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::Float:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()), static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()), static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          default:
            break;
        }
        break;
      case at::ScalarType::BFloat16:
        switch (y.scalar_type()) {
          case at::ScalarType::Half:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()), static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()), static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::BFloat16:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()), static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()), static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::Float:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()), static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()), static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          default:
            break;
        }
        break;
      case at::ScalarType::Float:
        switch (y.scalar_type()) {
          case at::ScalarType::Half:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()), static_cast<float *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()), static_cast<float *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::BFloat16:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()), static_cast<float *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()), static_cast<float *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::Float:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()), static_cast<float *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()), static_cast<float *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          default:
            break;
        }
        break;
      default:
        break;
    }
  }
  TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
              " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}

} // namespace

//====== pybind ======

#define DEFINE_pybind(name) m.def(#name, &name, #name);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
  m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
        "dispatch_bgmv_low_level");
}
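A rough usage sketch for the bindings defined above. The module name vllm._punica_C and the meaning of w's leading dimension (one slot per LoRA adapter) are assumptions; the tensor ranks, dtypes, and the [.., num_layers, h_out, h_in] layout follow the CHECK_* calls in dispatch_bgmv:

import torch
from vllm import _punica_C as punica_kernels  # assumed module name for this extension

B, h_in, h_out, num_loras, num_layers = 4, 16, 4096, 2, 1
x = torch.randn(B, h_in, dtype=torch.float16, device="cuda")
y = torch.zeros(B, h_out, dtype=torch.float16, device="cuda")
w = torch.randn(num_loras, num_layers, h_out, h_in,
                dtype=torch.float16, device="cuda")
indices = torch.zeros(B, dtype=torch.long, device="cuda")  # adapter id per row

# (h_in, h_out) must be one of the pairs instantiated via FOR_BGMV_WIDE_NARROW,
# otherwise the dispatcher raises "No suitable kernel."
punica_kernels.dispatch_bgmv(y, x, w, indices, 0, 1.0)  # layer_idx=0, scale=1.0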
@@ -51,10 +51,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 #ifndef USE_ROCM
   // Quantization ops
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
+  ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
 #endif
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
+  ops.def(
+    "moe_align_block_size",
+    &moe_align_block_size,
+    "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");

   // Cache ops
   pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
@@ -74,6 +79,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "gather_cached_kv",
     &gather_cached_kv,
     "Gather key and value from the cache into contiguous QKV tensors");
+  cache_ops.def(
+    "convert_fp8_e5m2",
+    &convert_fp8_e5m2,
+    "Convert the key and value cache to fp8_e5m2 data type");

   // Cuda utils
   pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
@@ -81,4 +90,26 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "get_device_attribute",
     &get_device_attribute,
     "Gets the specified device attribute.");

+  cuda_utils.def(
+    "get_max_shared_memory_per_block_device_attribute",
+    &get_max_shared_memory_per_block_device_attribute,
+    "Gets the maximum shared memory per block device attribute.");
+
+#ifndef USE_ROCM
+  // Custom all-reduce kernels
+  pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
+  custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
+  custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
+  custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
+  custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
+  custom_ar.def("dispose", &dispose, "dispose");
+  custom_ar.def("meta_size", &meta_size, "meta_size");
+  custom_ar.def("register_buffer", &register_buffer, "register_buffer");
+  custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
+                "get_graph_buffer_ipc_meta");
+  custom_ar.def("register_graph_buffers", &register_graph_buffers,
+                "register_graph_buffers");
+#endif
+
 }
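The docstring attached to moe_align_block_size above describes rounding each expert's token count up to a multiple of the kernel block size; a pure-Python sketch of that bookkeeping (illustrative only, not the signature of the CUDA op):

def aligned_expert_slots(tokens_per_expert, block_size):
    """Round each expert's token count up to the next multiple of block_size."""
    return [-(-n // block_size) * block_size for n in tokens_per_expert]

# With block_size=16, experts holding [3, 17, 0] tokens are padded to [16, 32, 0]
print(aligned_expert_slots([3, 17, 0], 16))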
@@ -493,9 +493,117 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
 #endif
 }
+
+__global__ void __launch_bounds__(64) dequantize_weights(
+    int* __restrict__ B,
+    half* __restrict__ scaling_factors,
+    int* __restrict__ zeros,
+    half* __restrict__ C,
+    int G
+)
+{
+  int j_factors1 = 4;
+  int row_stride2 = 4;
+  int split_k_iters = 1;
+  static constexpr uint32_t ZERO = 0x0;
+  half B_shared[32 * (128 + 8)];
+
+  half* B_shared_ptr2 = B_shared;
+
+  half B_shared_warp[32];
+  int OC = 512;
+
+  int N = blockDim.x * gridDim.x;  // 2
+  int col = (blockIdx.x * blockDim.x + threadIdx.x);
+  int row = blockIdx.y * blockDim.y + threadIdx.y;
+  int index1 = 8 * col + 8 * row * N;
+  half* C_ptr2 = C + index1;
+
+  int index2 = col + row * N;
+  int* B_ptr2 = B + index2;
+
+  int index3 = col + (int)(row / G) * N;
+  int* zeros_ptr2 = zeros + index3;
+  int index4 = 8 * col + (int)(row / G) * N * 8;
+  half* scaling_factors_ptr2 = scaling_factors + index4;
+
+
+  uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
+  uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+  uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
+  int j=0;
+
+  uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j);
+  uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
+
+  *(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16;
+
+  for (int i=0; i<8; ++i) {
+    *(C_ptr2 + i) = B_shared[i];
+  }
+}
+
 } // namespace awq
 } // namespace vllm
+
+torch::Tensor awq_dequantize(
+    torch::Tensor _kernel,
+    torch::Tensor _scaling_factors,
+    torch::Tensor _zeros,
+    int split_k_iters,
+    int thx,
+    int thy)
+{
+  int in_c = _kernel.size(0);
+  int qout_c = _kernel.size(1);
+  int out_c = qout_c * 8;
+  int G = in_c / _scaling_factors.size(0);
+
+  int x_thread = thx;
+  int y_thread = thy;
+
+  int x_blocks = 1;
+  int y_blocks = 1;
+  if (thx==0) {
+    x_thread = qout_c;
+  }
+  if (thy==0) {
+    y_thread = in_c;
+  }
+  if (thx==0 && thy==0) {
+    x_thread = 8;
+    y_thread = 8;
+    x_blocks = (int)(qout_c / 8);
+    y_blocks = (int)(in_c / 8);
+  }
+
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
+
+  auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
+  at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);
+
+  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
+  auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
+  auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
+  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
+
+  dim3 num_blocks(x_blocks, y_blocks);
+  dim3 threads_per_block(x_thread, y_thread);
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
+    kernel, scaling_factors, zeros, de_kernel, G);
+
+  return _de_kernel;
+}
+
 // in_feats: M, IC [float16]
 // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
 // scaling_factors: IC // G, OC [float16]
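A hedged sketch of calling the new awq_dequantize entry point from Python. The vllm._C.ops module path is taken from the pybind registration earlier in this diff; the exact packing of qzeros (int32 words holding 8 nibbles each, mirroring qweight) is an assumption, while the shapes follow the comments and size() arithmetic above:

import torch
from vllm._C import ops  # the extension module that registers awq_dequantize

IC, OC, G = 128, 512, 128  # in-features, out-features, quantization group size
qweight = torch.randint(0, 2**31 - 1, (IC, OC // 8), dtype=torch.int32, device="cuda")
scales = torch.randn(IC // G, OC, dtype=torch.float16, device="cuda")
qzeros = torch.randint(0, 2**31 - 1, (IC // G, OC // 8), dtype=torch.int32, device="cuda")

# thx=thy=0 lets the C++ fall back to the 8x8 thread-block layout shown above;
# split_k_iters appears unused on the dequantization path.
fp16_weight = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
print(fp16_weight.shape)  # torch.Size([IC, OC])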
278  csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh  Normal file
@@ -0,0 +1,278 @@
#pragma once

#include <assert.h>
#include <stdint.h>
#include <float.h>
#include <type_traits>
#include "../../attention/attention_dtypes.h"
#include "../../attention/dtype_float32.cuh"
#include "../../attention/dtype_float16.cuh"
#include "../../attention/dtype_bfloat16.cuh"

#pragma once

namespace vllm {
#ifdef ENABLE_FP8_E5M2
namespace fp8_e5m2_unscaled {

template<typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x)
{
    return x;
}

// fp8 -> half
template<>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
{
    __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
    return res.x;
}

// fp8x2 -> half2
template<>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
{
    union {
        uint16_t u16[2];
        uint32_t u32;
    } tmp;
    __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
    tmp.u16[0] = res.x;
    tmp.u16[1] = res.y;
    return tmp.u32;
}

// fp8x4 -> half2x2
template<>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
{
    union {
        uint2 u32x2;
        uint32_t u32[2];
    } tmp;
    tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
    tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
    return tmp.u32x2;
}

// fp8x8 -> half2x4
template<>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
{
    union {
        uint4 u64x2;
        uint2 u64[2];
    } tmp;
    tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
    tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
    return tmp.u64x2;
}

// fp8 -> __nv_bfloat16
template<>
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
{
    // Note there is no direct convert function from fp8 to bf16.
    // fp8 -> half
    __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
    // half -> float -> bf16
    float tmp = half_to_float(res.x);
    return __float2bfloat16(tmp);
}

// fp8x2 -> __nv_bfloat162
template<>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
{
    __nv_bfloat162 res;
    res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
    res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
    return res;
}

// fp8x4 -> bf16_4_t
template<>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
{
    bf16_4_t res;
    res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
    res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
    return res;
}

// fp8x8 -> bf16_8_t
template<>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
{
    bf16_4_t tmp1, tmp2;
    tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
    tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
    bf16_8_t res;
    res.x = tmp1.x;
    res.y = tmp1.y;
    res.z = tmp2.x;
    res.w = tmp2.y;
    return res;
}

// fp8 -> float
template<>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
{
    // fp8 -> half
    uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
    // half -> float
    return half_to_float(tmp);
}

// fp8x2 -> float2
template<>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
{
    // fp8x2 -> half2
    uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
    // half2 -> float2
    return half2_to_float2(tmp);
}

// fp8x4 -> float4
template<>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
{
    Float4_ res;
    res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
    res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
    return res;
}

// fp8x8 -> float8
template<>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
{
    Float4_ tmp1, tmp2;
    tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
    tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
    Float8_ res;
    res.x = tmp1.x;
    res.y = tmp1.y;
    res.z = tmp2.x;
    res.w = tmp2.y;
    return res;
}


// half -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
{
    __half_raw tmp;
    tmp.x = a;
    __nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
    return (uint8_t)res;
}

// bf16 -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    assert(false);
#else
    __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
    return (uint8_t)res;
#endif
}

// float -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
{
    __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
    return (uint8_t)res;
}

// fp8x4 -> float4
template<>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
{
    Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
    float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
    return res;
}


template<>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
{
    union {
        half2 float16;
        uint32_t uint32;
    };

    float16 = __float22half2_rn(a);
    return uint32;
}

template<>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
{
    uint2 b;
    float2 val;
    val.x = a.x.x;
    val.y = a.x.y;
    b.x = vec_conversion<uint32_t, float2>(val);

    val.x = a.y.x;
    val.y = a.y.y;
    b.y = vec_conversion<uint32_t, float2>(val);

    return b;
}

template<>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
{
    float4 b;
    b.x = a.x.x;
    b.y = a.x.y;
    b.z = a.y.x;
    b.w = a.y.y;
    return b;
}

template<>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
{
    uint4 b;
    b.x = vec_conversion<uint32_t, float2>(a.x);
    b.y = vec_conversion<uint32_t, float2>(a.y);
    b.z = vec_conversion<uint32_t, float2>(a.z);
    b.w = vec_conversion<uint32_t, float2>(a.w);
    return b;
}

template<>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
    __nv_bfloat162 b;
    from_float(b, a);
    return b;
}

template<>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
    bf16_4_t b;
    from_float(b, a);
    return b;
}

template<>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
    bf16_8_t b;
    from_float(b, a);
    return b;
}

} // namespace fp8_e5m2_unscaled
#endif // ENABLE_FP8_E5M2
} // namespace vllm
@@ -9,11 +9,15 @@
 # If extensions (or modules to document with autodoc) are in another directory,
 # add these directories to sys.path here. If the directory is relative to the
 # documentation root, use os.path.abspath to make it absolute, like shown here.
-#
-# import os
-# import sys
-# sys.path.insert(0, os.path.abspath('.'))
+
+import os
+import sys
+from sphinx.ext import autodoc
+import logging
+
+sys.path.insert(0, os.path.abspath(os.path.join('..', '..')))
+
+logger = logging.getLogger(__name__)

 # -- Project information -----------------------------------------------------
@@ -21,7 +25,6 @@ project = 'vLLM'
 copyright = '2023, vLLM Team'
 author = 'the vLLM Team'
-
 # -- General configuration ---------------------------------------------------

 # Add any Sphinx extension module names here, as strings. They can be
@@ -32,6 +35,8 @@ extensions = [
     "sphinx.ext.viewcode",
     "sphinx.ext.intersphinx",
     "sphinx_copybutton",
+    "sphinx.ext.autodoc",
+    "sphinx.ext.autosummary",
 ]

 # Add any paths that contain templates here, relative to this directory.
@@ -55,7 +60,6 @@ html_title = project
 html_theme = 'sphinx_book_theme'
 html_logo = 'assets/logos/vllm-logo-text-light.png'
 html_theme_options = {
-    'logo_only': True,
     'path_to_docs': 'docs/source',
     'repository_url': 'https://github.com/vllm-project/vllm',
     'use_repository_button': True,
@@ -64,4 +68,29 @@ html_theme_options = {
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+# html_static_path = ['_static']
+
+# Mock out external dependencies here.
+autodoc_mock_imports = [
+    "torch", "transformers", "psutil", "aioprometheus", "sentencepiece",
+    "vllm.cuda_utils", "vllm._C"
+]
+
+for mock_target in autodoc_mock_imports:
+    if mock_target in sys.modules:
+        logger.info(
+            f"Potentially problematic mock target ({mock_target}) found; "
+            "autodoc_mock_imports cannot mock modules that have already "
+            "been loaded into sys.modules when the sphinx build starts.")
+
+
+class MockedClassDocumenter(autodoc.ClassDocumenter):
+    """Remove note about base class when a class is derived from object."""
+
+    def add_line(self, line: str, source: str, *lineno: int) -> None:
+        if line == "   Bases: :py:class:`object`":
+            return
+        super().add_line(line, source, *lineno)
+
+
+autodoc.ClassDocumenter = MockedClassDocumenter
7  docs/source/dev/engine/async_llm_engine.rst  Normal file
@@ -0,0 +1,7 @@

AsyncLLMEngine
=================================

.. autoclass:: vllm.engine.async_llm_engine.AsyncLLMEngine
    :members: generate, abort
    :show-inheritance:

13  docs/source/dev/engine/engine_index.rst  Normal file
@@ -0,0 +1,13 @@
vLLM Engine
=================================

.. automodule:: vllm.engine
.. currentmodule:: vllm.engine

.. toctree::
   :maxdepth: 2
   :caption: Engines

   llm_engine
   async_llm_engine

6  docs/source/dev/engine/llm_engine.rst  Normal file
@@ -0,0 +1,6 @@
LLMEngine
=================================

.. autoclass:: vllm.engine.llm_engine.LLMEngine
    :members: add_request, abort_request, step, _init_cache
    :show-inheritance:
@@ -11,10 +11,10 @@ Requirements
 ------------

 * OS: Linux
-* Python: 3.8 -- 3.11 (Verified on 3.10)
+* Python: 3.8 -- 3.11
-* GPU: MI200s
+* GPU: MI200s (gfx90a), MI300 (gfx942)
 * Pytorch 2.0.1/2.1.1/2.2
-* ROCm 5.7
+* ROCm 5.7 (Verified on python 3.10) or ROCm 6.0 (Verified on python 3.9)

 Installation options:

@@ -27,6 +27,8 @@ Installation options:
 (Recommended) Option 1: Quick start with vLLM pre-installed in Docker Image
 ---------------------------------------------------------------------------
+
+This option is for ROCm 5.7 only:

 .. code-block:: console

    $ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4
@@ -50,6 +52,9 @@ Option 2: Build from source

 You can build and install vLLM from source:

+Below instruction is for ROCm 5.7 only.
+At the time of this documentation update, PyTorch on ROCm 6.0 wheel is not yet available on the PyTorch website.
+
 0. Install prerequisites (skip if you are already in an environment/docker with the following installed):

 - `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
@@ -95,6 +100,23 @@ You can build and install vLLM from source:

 Build a docker image from `Dockerfile.rocm`, and launch a docker container.

+The `Dockerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 and later versions. It provides flexibility to customize the build of the docker image using the following arguments:
+
+* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`
+* `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
+* `FA_BRANCH`: specifies the branch used to build the flash-attention in `ROCmSoftwarePlatform's flash-attention repo <https://github.com/ROCmSoftwarePlatform/flash-attention>`_. The default is `3d2b6f5`
+
+Their values can be passed in when running ``docker build`` with ``--build-arg`` options.
+
+For example, to build a docker image for vllm on ROCm 5.7, you can run:
+
+.. code-block:: console
+
+   $ docker build --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
+       -f Dockerfile.rocm -t vllm-rocm .
+
+To build vllm on ROCm 6.0, you can use the default:
+
 .. code-block:: console

    $ docker build -f Dockerfile.rocm -t vllm-rocm .
@@ -142,3 +164,8 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from
    $ cd vllm
    $ pip install -U -r requirements-rocm.txt
    $ python setup.py install  # This may take 5-10 minutes.
+
+.. note::
+
+    - You may need to turn on the ``--enforce-eager`` flag if you experience a process hang when running the `benchmark_throughput.py` script to test your installation.
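For reference, a minimal Python sketch of the same workaround; the enforce_eager engine argument is assumed to be the programmatic counterpart of the --enforce-eager flag mentioned in the note above:

from vllm import LLM

# Assumed equivalent of --enforce-eager: skip CUDA/HIP graph capture, run eagerly.
llm = LLM(model="facebook/opt-125m", enforce_eager=True)
print(llm.generate(["Hello, my name is"])[0].outputs[0].text)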
@@ -11,6 +11,14 @@ This guide shows how to use vLLM to:

 Be sure to complete the :ref:`installation instructions <installation>` before continuing with this guide.

+.. note::
+
+    By default, vLLM downloads model from `HuggingFace <https://huggingface.co/>`_. If you would like to use models from `ModelScope <https://www.modelscope.cn>`_ in the following examples, please set the environment variable:
+
+    .. code-block:: shell
+
+        export VLLM_USE_MODELSCOPE=True
+
 Offline Batched Inference
 -------------------------

@@ -40,16 +48,6 @@ Initialize vLLM's engine for offline inference with the ``LLM`` class and the `O

     llm = LLM(model="facebook/opt-125m")

-Use model from www.modelscope.cn
-
-.. code-block:: shell
-
-    export VLLM_USE_MODELSCOPE=True
-
-.. code-block:: python
-
-    llm = LLM(model="qwen/Qwen-7B-Chat", revision="v1.1.8", trust_remote_code=True)
-
 Call ``llm.generate`` to generate the outputs. It adds the input prompts to vLLM engine's waiting queue and executes the vLLM engine to generate the outputs with high throughput. The outputs are returned as a list of ``RequestOutput`` objects, which include all the output tokens.

 .. code-block:: python
@@ -65,49 +63,11 @@ Call ``llm.generate`` to generate the outputs. It adds the input prompts to vLLM

 The code example can also be found in `examples/offline_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py>`_.

-API Server
-----------
-
-vLLM can be deployed as an LLM service. We provide an example `FastAPI <https://fastapi.tiangolo.com/>`_ server. Check `vllm/entrypoints/api_server.py <https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/api_server.py>`_ for the server implementation. The server uses ``AsyncLLMEngine`` class to support asynchronous processing of incoming requests.
-
-Start the server:
-
-.. code-block:: console
-
-    $ python -m vllm.entrypoints.api_server
-
-Use model from www.modelscope.cn
-
-.. code-block:: console
-
-    $ VLLM_USE_MODELSCOPE=True python -m vllm.entrypoints.api_server \
-    $ --model="qwen/Qwen-7B-Chat" \
-    $ --revision="v1.1.8" \
-    $ --trust-remote-code
-
-By default, this command starts the server at ``http://localhost:8000`` with the OPT-125M model.
-
-Query the model in shell:
-
-.. code-block:: console
-
-    $ curl http://localhost:8000/generate \
-    $ -d '{
-    $ "prompt": "San Francisco is a",
-    $ "use_beam_search": true,
-    $ "n": 4,
-    $ "temperature": 0
-    $ }'
-
-See `examples/api_client.py <https://github.com/vllm-project/vllm/blob/main/examples/api_client.py>`_ for a more detailed client example.
-
 OpenAI-Compatible Server
 ------------------------

-vLLM can be deployed as a server that mimics the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API.
+vLLM can be deployed as a server that implements the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API.
-By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_, `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_, and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.
+By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the command below) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_, `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_, and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.

 Start the server:

@@ -116,13 +76,6 @@ Start the server:
     $ python -m vllm.entrypoints.openai.api_server \
     $ --model facebook/opt-125m

-Use model from www.modelscope.cn
-
-.. code-block:: console
-
-    $ VLLM_USE_MODELSCOPE=True python -m vllm.entrypoints.openai.api_server \
-    $ --model="qwen/Qwen-7B-Chat" --revision="v1.1.8" --trust-remote-code
-
 By default, the server uses a predefined chat template stored in the tokenizer. You can override this template by using the ``--chat-template`` argument:

 .. code-block:: console
@@ -137,6 +90,8 @@ This server can be queried in the same format as OpenAI API. For example, list t

     $ curl http://localhost:8000/v1/models

+You can pass in the argument ``--api-key`` or environment variable ``VLLM_API_KEY`` to enable the server to check for API key in the header.
+
 Using OpenAI Completions API with vLLM
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -31,7 +31,7 @@ vLLM is fast with:
 * Efficient management of attention key and value memory with **PagedAttention**
 * Continuous batching of incoming requests
 * Fast model execution with CUDA/HIP graph
-* Quantization: `GPTQ <https://arxiv.org/abs/2210.17323>`_, `AWQ <https://arxiv.org/abs/2306.00978>`_, `SqueezeLLM <https://arxiv.org/abs/2306.07629>`_
+* Quantization: `GPTQ <https://arxiv.org/abs/2210.17323>`_, `AWQ <https://arxiv.org/abs/2306.00978>`_, `SqueezeLLM <https://arxiv.org/abs/2306.07629>`_, FP8 KV Cache
 * Optimized CUDA kernels

 vLLM is flexible and easy to use with:
@@ -42,6 +42,8 @@ vLLM is flexible and easy to use with:
 * Streaming outputs
 * OpenAI-compatible API server
 * Support NVIDIA GPUs and AMD GPUs
+* (Experimental) Prefix caching support
+* (Experimental) Multi-lora support

 For more information, check out the following:

@@ -86,3 +88,15 @@ Documentation
    :caption: Quantization

    quantization/auto_awq
+
+.. toctree::
+   :maxdepth: 2
+   :caption: Developer Documentation
+
+   dev/engine/engine_index
+
+Indices and tables
+==================
+
+* :ref:`genindex`
+* :ref:`modindex`
|||||||
@@ -68,6 +68,12 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`QWenLMHeadModel`
     - Qwen
     - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
+  * - :code:`Qwen2ForCausalLM`
+    - Qwen2
+    - :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc.
+  * - :code:`StableLMEpochForCausalLM`
+    - StableLM
+    - :code:`stabilityai/stablelm-3b-4e1t`, :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc.
   * - :code:`YiForCausalLM`
     - Yi
     - :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
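As a quick illustration of how one of the newly listed architectures would be used, here is a minimal offline-inference sketch; the model name is taken from the table above, while ``trust_remote_code`` and the sampling settings are illustrative assumptions rather than part of this change:

.. code-block:: python

    from vllm import LLM, SamplingParams

    # Qwen2 checkpoints load through the same LLM entry point as the other
    # supported architectures; trust_remote_code is an assumption and may not
    # be required for every checkpoint.
    llm = LLM(model="Qwen/Qwen2-beta-7B-Chat", trust_remote_code=True)
    sampling_params = SamplingParams(temperature=0.7, max_tokens=64)

    outputs = llm.generate(["Briefly introduce yourself."], sampling_params)
    print(outputs[0].outputs[0].text)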
32	docs/source/quantization/fp8_e5m2_kv_cache.rst	Normal file
@@ -0,0 +1,32 @@
.. _fp8_e5m2_kv_cache:

FP8 E5M2 KV Cache
==================

The int8/int4 quantization scheme requires additional scale GPU memory storage, which reduces the expected GPU memory benefits.
The FP8 data format retains 2-3 mantissa bits and can convert between float/fp16/bfloat16 and fp8.

Here is an example of how to enable this feature:

.. code-block:: python

    from vllm import LLM, SamplingParams

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    # Create an LLM.
    llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8_e5m2")
    # Generate texts from the prompts. The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
81	examples/gradio_openai_chatbot_webserver.py	Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import argparse
|
||||||
|
from openai import OpenAI
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
# Argument parser setup
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description='Chatbot Interface with Customizable Parameters')
|
||||||
|
parser.add_argument('--model-url',
|
||||||
|
type=str,
|
||||||
|
default='http://localhost:8000/v1',
|
||||||
|
help='Model URL')
|
||||||
|
parser.add_argument('-m',
|
||||||
|
'--model',
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help='Model name for the chatbot')
|
||||||
|
parser.add_argument('--temp',
|
||||||
|
type=float,
|
||||||
|
default=0.8,
|
||||||
|
help='Temperature for text generation')
|
||||||
|
parser.add_argument('--stop-token-ids',
|
||||||
|
type=str,
|
||||||
|
default='',
|
||||||
|
help='Comma-separated stop token IDs')
|
||||||
|
parser.add_argument("--host", type=str, default=None)
|
||||||
|
parser.add_argument("--port", type=int, default=8001)
|
||||||
|
|
||||||
|
# Parse the arguments
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Set OpenAI's API key and API base to use vLLM's API server.
|
||||||
|
openai_api_key = "EMPTY"
|
||||||
|
openai_api_base = args.model_url
|
||||||
|
|
||||||
|
# Create an OpenAI client to interact with the API server
|
||||||
|
client = OpenAI(
|
||||||
|
api_key=openai_api_key,
|
||||||
|
base_url=openai_api_base,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def predict(message, history):
|
||||||
|
# Convert chat history to OpenAI format
|
||||||
|
history_openai_format = [{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a great ai assistant."
|
||||||
|
}]
|
||||||
|
for human, assistant in history:
|
||||||
|
history_openai_format.append({"role": "user", "content": human})
|
||||||
|
history_openai_format.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": assistant
|
||||||
|
})
|
||||||
|
history_openai_format.append({"role": "user", "content": message})
|
||||||
|
|
||||||
|
# Create a chat completion request and send it to the API server
|
||||||
|
stream = client.chat.completions.create(
|
||||||
|
model=args.model, # Model name to use
|
||||||
|
messages=history_openai_format, # Chat history
|
||||||
|
temperature=args.temp, # Temperature for text generation
|
||||||
|
stream=True, # Stream response
|
||||||
|
extra_body={
|
||||||
|
'repetition_penalty':
|
||||||
|
1,
|
||||||
|
'stop_token_ids': [
|
||||||
|
int(id.strip()) for id in args.stop_token_ids.split(',')
|
||||||
|
if id.strip()
|
||||||
|
] if args.stop_token_ids else []
|
||||||
|
})
|
||||||
|
|
||||||
|
# Read and return generated text from response stream
|
||||||
|
partial_message = ""
|
||||||
|
for chunk in stream:
|
||||||
|
partial_message += (chunk.choices[0].delta.content or "")
|
||||||
|
yield partial_message
|
||||||
|
|
||||||
|
|
||||||
|
# Create and launch a chat interface with Gradio
|
||||||
|
gr.ChatInterface(predict).queue().launch(server_name=args.host,
|
||||||
|
server_port=args.port,
|
||||||
|
share=True)
|
||||||
117	examples/multilora_inference.py	Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
"""
|
||||||
|
This example shows how to use the multi-LoRA functionality for offline inference.
|
||||||
|
|
||||||
|
Requires HuggingFace credentials for access to Llama2.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]:
|
||||||
|
"""Create a list of test prompts with their sampling parameters.
|
||||||
|
|
||||||
|
2 requests for base model, 4 requests for the LoRA. We define 2
|
||||||
|
different LoRA adapters (using the same model for demo purposes).
|
||||||
|
Since we also set `max_loras=1`, the expectation is that the requests
|
||||||
|
with the second LoRA adapter will be run after all requests with the
|
||||||
|
first adapter have finished.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
("A robot may not injure a human being",
|
||||||
|
SamplingParams(temperature=0.0,
|
||||||
|
logprobs=1,
|
||||||
|
prompt_logprobs=1,
|
||||||
|
max_tokens=128), None),
|
||||||
|
("To be or not to be,",
|
||||||
|
SamplingParams(temperature=0.8,
|
||||||
|
top_k=5,
|
||||||
|
presence_penalty=0.2,
|
||||||
|
max_tokens=128), None),
|
||||||
|
("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
|
||||||
|
SamplingParams(temperature=0.0,
|
||||||
|
logprobs=1,
|
||||||
|
prompt_logprobs=1,
|
||||||
|
max_tokens=128,
|
||||||
|
stop_token_ids=[32003]),
|
||||||
|
LoRARequest("sql-lora", 1, lora_path)),
|
||||||
|
("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
|
||||||
|
SamplingParams(n=3,
|
||||||
|
best_of=3,
|
||||||
|
use_beam_search=True,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=128,
|
||||||
|
stop_token_ids=[32003]),
|
||||||
|
LoRARequest("sql-lora", 1, lora_path)),
|
||||||
|
("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
|
||||||
|
SamplingParams(temperature=0.0,
|
||||||
|
logprobs=1,
|
||||||
|
prompt_logprobs=1,
|
||||||
|
max_tokens=128,
|
||||||
|
stop_token_ids=[32003]),
|
||||||
|
LoRARequest("sql-lora2", 2, lora_path)),
|
||||||
|
("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
|
||||||
|
SamplingParams(n=3,
|
||||||
|
best_of=3,
|
||||||
|
use_beam_search=True,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=128,
|
||||||
|
stop_token_ids=[32003]),
|
||||||
|
LoRARequest("sql-lora", 1, lora_path)),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def process_requests(engine: LLMEngine,
|
||||||
|
test_prompts: List[Tuple[str, SamplingParams,
|
||||||
|
Optional[LoRARequest]]]):
|
||||||
|
"""Continuously process a list of prompts and handle the outputs."""
|
||||||
|
request_id = 0
|
||||||
|
|
||||||
|
while test_prompts or engine.has_unfinished_requests():
|
||||||
|
if test_prompts:
|
||||||
|
prompt, sampling_params, lora_request = test_prompts.pop(0)
|
||||||
|
engine.add_request(str(request_id),
|
||||||
|
prompt,
|
||||||
|
sampling_params,
|
||||||
|
lora_request=lora_request)
|
||||||
|
request_id += 1
|
||||||
|
|
||||||
|
request_outputs: List[RequestOutput] = engine.step()
|
||||||
|
|
||||||
|
for request_output in request_outputs:
|
||||||
|
if request_output.finished:
|
||||||
|
print(request_output)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_engine() -> LLMEngine:
|
||||||
|
"""Initialize the LLMEngine."""
|
||||||
|
# max_loras: controls the number of LoRAs that can be used in the same
|
||||||
|
# batch. Larger numbers will cause higher memory usage, as each LoRA
|
||||||
|
# slot requires its own preallocated tensor.
|
||||||
|
# max_lora_rank: controls the maximum supported rank of all LoRAs. Larger
|
||||||
|
# numbers will cause higher memory usage. If you know that all LoRAs will
|
||||||
|
# use the same rank, it is recommended to set this as low as possible.
|
||||||
|
# max_cpu_loras: controls the size of the CPU LoRA cache.
|
||||||
|
engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
|
||||||
|
enable_lora=True,
|
||||||
|
max_loras=1,
|
||||||
|
max_lora_rank=8,
|
||||||
|
max_cpu_loras=2,
|
||||||
|
max_num_seqs=256)
|
||||||
|
return LLMEngine.from_engine_args(engine_args)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main function that sets up and runs the prompt processing."""
|
||||||
|
engine = initialize_engine()
|
||||||
|
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
|
||||||
|
test_prompts = create_test_prompts(lora_path)
|
||||||
|
process_requests(engine, test_prompts)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
59	examples/offline_inference_with_prefix.py	Normal file
@@ -0,0 +1,59 @@
from vllm import LLM, SamplingParams

prefix = (
    "You are an expert school principal, skilled in effectively managing "
    "faculty and staff. Draft 10-15 questions for a potential first grade "
    "Head Teacher for my K-12, all-girls', independent school that emphasizes "
    "community, joyful discovery, and life-long learning. The candidate is "
    "coming in for a first-round panel interview for a 8th grade Math "
    "teaching role. They have 5 years of previous teaching experience "
    "as an assistant teacher at a co-ed, public school with experience "
    "in middle school math teaching. Based on these information, fulfill "
    "the following paragraph: ")

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")

generating_prompts = [prefix + prompt for prompt in prompts]

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(generating_prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print("-" * 80)

# -1 since the last token can change when concatenating prompts.
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1

# The llm.generate call will batch all prompts and send the batch at once if resources allow.
# The prefix will only be cached after the first batch is processed, so we need to call generate once
# to calculate the prefix and cache it.
outputs = llm.generate(generating_prompts[0],
                       sampling_params,
                       prefix_pos=[prefix_pos])

# Subsequent batches can leverage the cached prefix
outputs = llm.generate(generating_prompts,
                       sampling_params,
                       prefix_pos=[prefix_pos] * len(generating_prompts))

# Print the outputs. You should see the same outputs as before
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@@ -32,6 +32,5 @@ chat_completion = client.chat.completions.create(
     model=model,
 )

-
 print("Chat completion results:")
 print(chat_completion)
@@ -21,8 +21,7 @@ completion = client.completions.create(
     echo=False,
     n=2,
     stream=stream,
-    logprobs=3
-)
+    logprobs=3)

 print("Completion results:")
 if stream:
22	examples/template_baichuan.jinja	Normal file
@@ -0,0 +1,22 @@
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}

{% for message in messages %}
{% if message['role'] == 'user' %}
<reserved_106>
{{ message['content']|trim -}}
{% if not loop.last %}


{% endif %}
{% elif message['role'] == 'assistant' %}
<reserved_107>
{{ message['content']|trim -}}
{% if not loop.last %}


{% endif %}
{% endif %}
{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
<reserved_107>
{% endif %}
@@ -71,7 +71,7 @@ format_changed() {

 # Format all files
 format_all() {
-    yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm tests
+    yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" .
 }

 ## This flag formats individual files. --files *must* be the first command line
@@ -13,4 +13,9 @@ types-setuptools
 pytest
 pytest-forked
 pytest-asyncio
+httpx
+einops  # required for MPT
+flash_attn  # required for HuggingFace's llama implementation
+openai
+requests
+ray
9	requirements-neuron.txt	Normal file
@@ -0,0 +1,9 @@
sentencepiece  # Required for LLaMA tokenizer.
numpy
transformers-neuronx >= 0.9.0
torch-neuronx >= 2.1.0
neuronx-cc
fastapi
uvicorn[standard]
pydantic >= 2.0  # Required for OpenAI server.
aioprometheus[starlette]
@@ -2,12 +2,12 @@ ninja  # For faster builds.
 typing-extensions>=4.8.0
 starlette
 psutil
-ray >= 2.5.1
+ray >= 2.9
 sentencepiece  # Required for LLaMA tokenizer.
 numpy
 tokenizers>=0.15.0
-transformers >= 4.36.0  # Required for Mixtral.
+transformers >= 4.37.0  # Required for Mixtral.
 fastapi
 uvicorn[standard]
-pydantic == 1.10.13  # Required for OpenAI server.
+pydantic >= 2.0  # Required for OpenAI server.
 aioprometheus[starlette]
@@ -1,12 +1,13 @@
 ninja  # For faster builds.
 psutil
-ray >= 2.5.1
+ray >= 2.9
 sentencepiece  # Required for LLaMA tokenizer.
 numpy
 torch == 2.1.2
-transformers >= 4.36.0  # Required for Mixtral.
+transformers >= 4.37.0  # Required for Qwen2
 xformers == 0.0.23.post1  # Required for CUDA 12.1.
 fastapi
 uvicorn[standard]
-pydantic == 1.10.13  # Required for OpenAI server.
+pydantic >= 2.0  # Required for OpenAI server.
 aioprometheus[starlette]
+pynvml == 11.5.0
146	setup.py
@@ -1,13 +1,16 @@
+import contextlib
 import io
 import os
 import re
 import subprocess
-from typing import List, Set
 import warnings
+from pathlib import Path
+from typing import List, Set

 from packaging.version import parse, Version
 import setuptools
 import torch
+import torch.utils.cpp_extension as torch_cpp_ext
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME

 ROOT_DIR = os.path.dirname(__file__)
@@ -24,8 +27,17 @@ def _is_hip() -> bool:
     return torch.version.hip is not None


+def _is_neuron() -> bool:
+    torch_neuronx_installed = True
+    try:
+        subprocess.run(["neuron-ls"], capture_output=True, check=True)
+    except FileNotFoundError:
+        torch_neuronx_installed = False
+    return torch_neuronx_installed
+
+
 def _is_cuda() -> bool:
-    return torch.version.cuda is not None
+    return (torch.version.cuda is not None) and not _is_neuron()


 # Compiler flags.
@@ -39,6 +51,8 @@ if _is_hip():
|
|||||||
"Cannot find ROCM_HOME. ROCm must be available to build the package."
|
"Cannot find ROCM_HOME. ROCm must be available to build the package."
|
||||||
)
|
)
|
||||||
NVCC_FLAGS += ["-DUSE_ROCM"]
|
NVCC_FLAGS += ["-DUSE_ROCM"]
|
||||||
|
NVCC_FLAGS += ["-U__HIP_NO_HALF_CONVERSIONS__"]
|
||||||
|
NVCC_FLAGS += ["-U__HIP_NO_HALF_OPERATORS__"]
|
||||||
|
|
||||||
if _is_cuda() and CUDA_HOME is None:
|
if _is_cuda() and CUDA_HOME is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@@ -87,6 +101,30 @@ def get_hipcc_rocm_version():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def glob(pattern: str):
|
||||||
|
root = Path(__name__).parent
|
||||||
|
return [str(p) for p in root.glob(pattern)]
|
||||||
|
|
||||||
|
|
||||||
|
def get_neuronxcc_version():
|
||||||
|
import sysconfig
|
||||||
|
site_dir = sysconfig.get_paths()["purelib"]
|
||||||
|
version_file = os.path.join(site_dir, "neuronxcc", "version",
|
||||||
|
"__init__.py")
|
||||||
|
|
||||||
|
# Check if the command was executed successfully
|
||||||
|
with open(version_file, "rt") as fp:
|
||||||
|
content = fp.read()
|
||||||
|
|
||||||
|
# Extract the version using a regular expression
|
||||||
|
match = re.search(r"__version__ = '(\S+)'", content)
|
||||||
|
if match:
|
||||||
|
# Return the version string
|
||||||
|
return match.group(1)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Could not find HIP version in the output")
|
||||||
|
|
||||||
|
|
||||||
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
||||||
"""Get the CUDA version from nvcc.
|
"""Get the CUDA version from nvcc.
|
||||||
|
|
||||||
@@ -151,6 +189,8 @@ if _is_cuda() and not compute_capabilities:
|
|||||||
"GPUs with compute capability below 7.0 are not supported.")
|
"GPUs with compute capability below 7.0 are not supported.")
|
||||||
compute_capabilities.add(f"{major}.{minor}")
|
compute_capabilities.add(f"{major}.{minor}")
|
||||||
|
|
||||||
|
ext_modules = []
|
||||||
|
|
||||||
if _is_cuda():
|
if _is_cuda():
|
||||||
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
||||||
if not compute_capabilities:
|
if not compute_capabilities:
|
||||||
@@ -188,6 +228,8 @@ if _is_cuda():
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"CUDA 11.8 or higher is required for compute capability 9.0.")
|
"CUDA 11.8 or higher is required for compute capability 9.0.")
|
||||||
|
|
||||||
|
NVCC_FLAGS_PUNICA = NVCC_FLAGS.copy()
|
||||||
|
|
||||||
# Add target compute capabilities to NVCC flags.
|
# Add target compute capabilities to NVCC flags.
|
||||||
for capability in compute_capabilities:
|
for capability in compute_capabilities:
|
||||||
num = capability[0] + capability[2]
|
num = capability[0] + capability[2]
|
||||||
@@ -196,6 +238,14 @@ if _is_cuda():
|
|||||||
NVCC_FLAGS += [
|
NVCC_FLAGS += [
|
||||||
"-gencode", f"arch=compute_{num},code=compute_{num}"
|
"-gencode", f"arch=compute_{num},code=compute_{num}"
|
||||||
]
|
]
|
||||||
|
if int(capability[0]) >= 8:
|
||||||
|
NVCC_FLAGS_PUNICA += [
|
||||||
|
"-gencode", f"arch=compute_{num},code=sm_{num}"
|
||||||
|
]
|
||||||
|
if capability.endswith("+PTX"):
|
||||||
|
NVCC_FLAGS_PUNICA += [
|
||||||
|
"-gencode", f"arch=compute_{num},code=compute_{num}"
|
||||||
|
]
|
||||||
|
|
||||||
# Use NVCC threads to parallelize the build.
|
# Use NVCC threads to parallelize the build.
|
||||||
if nvcc_cuda_version >= Version("11.2"):
|
if nvcc_cuda_version >= Version("11.2"):
|
||||||
@@ -203,14 +253,52 @@ if _is_cuda():
|
|||||||
num_threads = min(os.cpu_count(), nvcc_threads)
|
num_threads = min(os.cpu_count(), nvcc_threads)
|
||||||
NVCC_FLAGS += ["--threads", str(num_threads)]
|
NVCC_FLAGS += ["--threads", str(num_threads)]
|
||||||
|
|
||||||
elif _is_hip():
|
if nvcc_cuda_version >= Version("11.8"):
|
||||||
amd_arch = get_amdgpu_offload_arch()
|
NVCC_FLAGS += ["-DENABLE_FP8_E5M2"]
|
||||||
if amd_arch not in ROCM_SUPPORTED_ARCHS:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
|
|
||||||
f"amdgpu_arch_found: {amd_arch}")
|
|
||||||
|
|
||||||
ext_modules = []
|
# changes for punica kernels
|
||||||
|
NVCC_FLAGS += torch_cpp_ext.COMMON_NVCC_FLAGS
|
||||||
|
REMOVE_NVCC_FLAGS = [
|
||||||
|
'-D__CUDA_NO_HALF_OPERATORS__',
|
||||||
|
'-D__CUDA_NO_HALF_CONVERSIONS__',
|
||||||
|
'-D__CUDA_NO_BFLOAT16_CONVERSIONS__',
|
||||||
|
'-D__CUDA_NO_HALF2_OPERATORS__',
|
||||||
|
]
|
||||||
|
for flag in REMOVE_NVCC_FLAGS:
|
||||||
|
with contextlib.suppress(ValueError):
|
||||||
|
torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag)
|
||||||
|
|
||||||
|
install_punica = bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0")))
|
||||||
|
device_count = torch.cuda.device_count()
|
||||||
|
for i in range(device_count):
|
||||||
|
major, minor = torch.cuda.get_device_capability(i)
|
||||||
|
if major < 8:
|
||||||
|
install_punica = False
|
||||||
|
break
|
||||||
|
if install_punica:
|
||||||
|
ext_modules.append(
|
||||||
|
CUDAExtension(
|
||||||
|
name="vllm._punica_C",
|
||||||
|
sources=["csrc/punica/punica_ops.cc"] +
|
||||||
|
glob("csrc/punica/bgmv/*.cu"),
|
||||||
|
extra_compile_args={
|
||||||
|
"cxx": CXX_FLAGS,
|
||||||
|
"nvcc": NVCC_FLAGS_PUNICA,
|
||||||
|
},
|
||||||
|
))
|
||||||
|
elif _is_hip():
|
||||||
|
amd_archs = os.getenv("GPU_ARCHS")
|
||||||
|
if amd_archs is None:
|
||||||
|
amd_archs = get_amdgpu_offload_arch()
|
||||||
|
for arch in amd_archs.split(";"):
|
||||||
|
if arch not in ROCM_SUPPORTED_ARCHS:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
|
||||||
|
f"amdgpu_arch_found: {arch}")
|
||||||
|
NVCC_FLAGS += [f"--offload-arch={arch}"]
|
||||||
|
|
||||||
|
elif _is_neuron():
|
||||||
|
neuronxcc_version = get_neuronxcc_version()
|
||||||
|
|
||||||
vllm_extension_sources = [
|
vllm_extension_sources = [
|
||||||
"csrc/cache_kernels.cu",
|
"csrc/cache_kernels.cu",
|
||||||
@@ -221,21 +309,25 @@ vllm_extension_sources = [
|
|||||||
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
|
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
|
||||||
"csrc/quantization/gptq/q_gemm.cu",
|
"csrc/quantization/gptq/q_gemm.cu",
|
||||||
"csrc/cuda_utils_kernels.cu",
|
"csrc/cuda_utils_kernels.cu",
|
||||||
|
"csrc/moe_align_block_size_kernels.cu",
|
||||||
"csrc/pybind.cpp",
|
"csrc/pybind.cpp",
|
||||||
]
|
]
|
||||||
|
|
||||||
if _is_cuda():
|
if _is_cuda():
|
||||||
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
|
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
|
||||||
|
vllm_extension_sources.append("csrc/custom_all_reduce.cu")
|
||||||
|
|
||||||
vllm_extension = CUDAExtension(
|
if not _is_neuron():
|
||||||
name="vllm._C",
|
vllm_extension = CUDAExtension(
|
||||||
sources=vllm_extension_sources,
|
name="vllm._C",
|
||||||
extra_compile_args={
|
sources=vllm_extension_sources,
|
||||||
"cxx": CXX_FLAGS,
|
extra_compile_args={
|
||||||
"nvcc": NVCC_FLAGS,
|
"cxx": CXX_FLAGS,
|
||||||
},
|
"nvcc": NVCC_FLAGS,
|
||||||
)
|
},
|
||||||
ext_modules.append(vllm_extension)
|
libraries=["cuda"] if _is_cuda() else [],
|
||||||
|
)
|
||||||
|
ext_modules.append(vllm_extension)
|
||||||
|
|
||||||
|
|
||||||
def get_path(*filepath) -> str:
|
def get_path(*filepath) -> str:
|
||||||
@@ -264,6 +356,12 @@ def get_vllm_version() -> str:
|
|||||||
if hipcc_version != MAIN_CUDA_VERSION:
|
if hipcc_version != MAIN_CUDA_VERSION:
|
||||||
rocm_version_str = hipcc_version.replace(".", "")[:3]
|
rocm_version_str = hipcc_version.replace(".", "")[:3]
|
||||||
version += f"+rocm{rocm_version_str}"
|
version += f"+rocm{rocm_version_str}"
|
||||||
|
elif _is_neuron():
|
||||||
|
# Get the Neuron version
|
||||||
|
neuron_version = str(neuronxcc_version)
|
||||||
|
if neuron_version != MAIN_CUDA_VERSION:
|
||||||
|
neuron_version_str = neuron_version.replace(".", "")[:3]
|
||||||
|
version += f"+neuron{neuron_version_str}"
|
||||||
else:
|
else:
|
||||||
cuda_version = str(nvcc_cuda_version)
|
cuda_version = str(nvcc_cuda_version)
|
||||||
if cuda_version != MAIN_CUDA_VERSION:
|
if cuda_version != MAIN_CUDA_VERSION:
|
||||||
@@ -287,12 +385,20 @@ def get_requirements() -> List[str]:
|
|||||||
if _is_hip():
|
if _is_hip():
|
||||||
with open(get_path("requirements-rocm.txt")) as f:
|
with open(get_path("requirements-rocm.txt")) as f:
|
||||||
requirements = f.read().strip().split("\n")
|
requirements = f.read().strip().split("\n")
|
||||||
|
elif _is_neuron():
|
||||||
|
with open(get_path("requirements-neuron.txt")) as f:
|
||||||
|
requirements = f.read().strip().split("\n")
|
||||||
else:
|
else:
|
||||||
with open(get_path("requirements.txt")) as f:
|
with open(get_path("requirements.txt")) as f:
|
||||||
requirements = f.read().strip().split("\n")
|
requirements = f.read().strip().split("\n")
|
||||||
return requirements
|
return requirements
|
||||||
|
|
||||||
|
|
||||||
|
package_data = {"vllm": ["py.typed"]}
|
||||||
|
if os.environ.get("VLLM_USE_PRECOMPILED"):
|
||||||
|
ext_modules = []
|
||||||
|
package_data["vllm"].append("*.so")
|
||||||
|
|
||||||
setuptools.setup(
|
setuptools.setup(
|
||||||
name="vllm",
|
name="vllm",
|
||||||
version=get_vllm_version(),
|
version=get_vllm_version(),
|
||||||
@@ -320,6 +426,6 @@ setuptools.setup(
     python_requires=">=3.8",
     install_requires=get_requirements(),
     ext_modules=ext_modules,
-    cmdclass={"build_ext": BuildExtension},
-    package_data={"vllm": ["py.typed"]},
+    cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {},
+    package_data=package_data,
 )
@@ -29,8 +29,13 @@ def api_server():
     script_path = Path(__file__).parent.joinpath(
         "api_server_async_engine.py").absolute()
     uvicorn_process = subprocess.Popen([
-        sys.executable, "-u",
-        str(script_path), "--model", "facebook/opt-125m"
+        sys.executable,
+        "-u",
+        str(script_path),
+        "--model",
+        "facebook/opt-125m",
+        "--host",
+        "127.0.0.1",
     ])
     yield
     uvicorn_process.terminate()
@@ -81,6 +86,9 @@ def test_api_server(api_server):
     pool.join()

     # check cancellation stats
+    # give it some time to update the stats
+    time.sleep(1)
+
     num_aborted_requests = requests.get(
         "http://localhost:8000/stats").json()["num_aborted_requests"]
     assert num_aborted_requests > 0
@@ -25,6 +25,13 @@ class MockEngine:
         return [RequestOutput(
             request_id=self.request_id)] if self.request_id else []

+    async def encode_request_async(
+            self,
+            *args,
+            **kwargs,
+    ):
+        return [1]
+
     def generate(self, request_id):
         self.request_id = request_id

@@ -35,6 +42,10 @@ class MockEngine:
         del kwargs  # Unused
         self.add_request_calls += 1

+    async def add_request_async(self, **kwargs):
+        del kwargs  # Unused
+        self.add_request_calls += 1
+
     def abort_request(self, request_id):
         del request_id  # Unused
         self.abort_request_calls += 1
@@ -1,10 +1,16 @@
|
|||||||
from argparse import Namespace
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
|
|
||||||
from vllm.entrypoints.openai.api_server import *
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||||
|
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||||
|
|
||||||
|
chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
|
||||||
|
__file__))).parent.parent / "examples/template_chatml.jinja"
|
||||||
|
assert chatml_jinja_path.exists()
|
||||||
|
|
||||||
# Define models, templates, and their corresponding expected outputs
|
# Define models, templates, and their corresponding expected outputs
|
||||||
MODEL_TEMPLATE_GENERATON_OUTPUT = [
|
MODEL_TEMPLATE_GENERATON_OUTPUT = [
|
||||||
@@ -12,8 +18,7 @@ MODEL_TEMPLATE_GENERATON_OUTPUT = [
|
|||||||
"Hello</s>Hi there!</s>What is the capital of</s>"),
|
"Hello</s>Hi there!</s>What is the capital of</s>"),
|
||||||
("facebook/opt-125m", None, False,
|
("facebook/opt-125m", None, False,
|
||||||
"Hello</s>Hi there!</s>What is the capital of</s>"),
|
"Hello</s>Hi there!</s>What is the capital of</s>"),
|
||||||
("facebook/opt-125m", "../../examples/template_chatml.jinja", True,
|
("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
|
||||||
"""<|im_start|>user
|
|
||||||
Hello<|im_end|>
|
Hello<|im_end|>
|
||||||
<|im_start|>assistant
|
<|im_start|>assistant
|
||||||
Hi there!<|im_end|>
|
Hi there!<|im_end|>
|
||||||
@@ -21,8 +26,7 @@ Hi there!<|im_end|>
|
|||||||
What is the capital of<|im_end|>
|
What is the capital of<|im_end|>
|
||||||
<|im_start|>assistant
|
<|im_start|>assistant
|
||||||
"""),
|
"""),
|
||||||
("facebook/opt-125m", "../../examples/template_chatml.jinja", False,
|
("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user
|
||||||
"""<|im_start|>user
|
|
||||||
Hello<|im_end|>
|
Hello<|im_end|>
|
||||||
<|im_start|>assistant
|
<|im_start|>assistant
|
||||||
Hi there!<|im_end|>
|
Hi there!<|im_end|>
|
||||||
@@ -44,7 +48,6 @@ TEST_MESSAGES = [
|
|||||||
'content': 'What is the capital of'
|
'content': 'What is the capital of'
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
client = TestClient(app)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -52,14 +55,17 @@ class MockTokenizer:
|
|||||||
chat_template = None
|
chat_template = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockServingChat:
|
||||||
|
tokenizer: MockTokenizer
|
||||||
|
|
||||||
|
|
||||||
def test_load_chat_template():
|
def test_load_chat_template():
|
||||||
# Testing chatml template
|
# Testing chatml template
|
||||||
template = "../../examples/template_chatml.jinja"
|
|
||||||
mock_args = Namespace(chat_template=template)
|
|
||||||
tokenizer = MockTokenizer()
|
tokenizer = MockTokenizer()
|
||||||
|
mock_serving_chat = MockServingChat(tokenizer)
|
||||||
# Call the function with the mocked args
|
OpenAIServingChat._load_chat_template(mock_serving_chat,
|
||||||
load_chat_template(mock_args, tokenizer)
|
chat_template=chatml_jinja_path)
|
||||||
|
|
||||||
template_content = tokenizer.chat_template
|
template_content = tokenizer.chat_template
|
||||||
|
|
||||||
@@ -73,11 +79,11 @@ def test_load_chat_template():
|
|||||||
def test_no_load_chat_template():
|
def test_no_load_chat_template():
|
||||||
# Testing chatml template
|
# Testing chatml template
|
||||||
template = "../../examples/does_not_exist"
|
template = "../../examples/does_not_exist"
|
||||||
mock_args = Namespace(chat_template=template)
|
|
||||||
tokenizer = MockTokenizer()
|
tokenizer = MockTokenizer()
|
||||||
|
|
||||||
# Call the function with the mocked args
|
mock_serving_chat = MockServingChat(tokenizer)
|
||||||
load_chat_template(mock_args, tokenizer=tokenizer)
|
OpenAIServingChat._load_chat_template(mock_serving_chat,
|
||||||
|
chat_template=template)
|
||||||
template_content = tokenizer.chat_template
|
template_content = tokenizer.chat_template
|
||||||
|
|
||||||
# Test assertions
|
# Test assertions
|
||||||
@@ -94,9 +100,9 @@ async def test_get_gen_prompt(model, template, add_generation_prompt,
|
|||||||
expected_output):
|
expected_output):
|
||||||
# Initialize the tokenizer
|
# Initialize the tokenizer
|
||||||
tokenizer = get_tokenizer(tokenizer_name=model)
|
tokenizer = get_tokenizer(tokenizer_name=model)
|
||||||
|
mock_serving_chat = MockServingChat(tokenizer)
|
||||||
mock_args = Namespace(chat_template=template)
|
OpenAIServingChat._load_chat_template(mock_serving_chat,
|
||||||
load_chat_template(mock_args, tokenizer)
|
chat_template=template)
|
||||||
|
|
||||||
# Create a mock request object using keyword arguments
|
# Create a mock request object using keyword arguments
|
||||||
mock_request = ChatCompletionRequest(
|
mock_request = ChatCompletionRequest(
|
||||||
@@ -112,8 +118,3 @@ async def test_get_gen_prompt(model, template, add_generation_prompt,
|
|||||||
|
|
||||||
# Test assertion
|
# Test assertion
|
||||||
assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}"
|
assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}"
|
||||||
|
|
||||||
|
|
||||||
def test_health_endpoint():
|
|
||||||
response = client.get("/health")
|
|
||||||
assert response.status_code == 200
|
|
||||||
@@ -2,32 +2,20 @@
|
|||||||
|
|
||||||
Run `pytest tests/distributed/test_comm_ops.py --forked`.
|
Run `pytest tests/distributed/test_comm_ops.py --forked`.
|
||||||
"""
|
"""
|
||||||
from multiprocessing import Process, set_start_method
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
import ray
|
||||||
|
|
||||||
from vllm.config import ParallelConfig
|
|
||||||
from vllm.utils import get_open_port
|
|
||||||
from vllm.model_executor.parallel_utils.communication_op import (
|
from vllm.model_executor.parallel_utils.communication_op import (
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
tensor_model_parallel_all_gather,
|
tensor_model_parallel_all_gather,
|
||||||
|
broadcast_tensor_dict,
|
||||||
)
|
)
|
||||||
from vllm.worker.worker import _init_distributed_environment
|
from vllm.test_utils import (init_test_distributed_environment,
|
||||||
|
multi_process_tensor_parallel)
|
||||||
|
|
||||||
def init_test_distributed_environment(pipeline_parallel_size: int,
|
|
||||||
tensor_parallel_size: int, rank: int,
|
|
||||||
distributed_init_port: str):
|
|
||||||
parallel_config = ParallelConfig(pipeline_parallel_size,
|
|
||||||
tensor_parallel_size,
|
|
||||||
worker_use_ray=True)
|
|
||||||
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
|
|
||||||
torch.cuda.set_device(rank)
|
|
||||||
_init_distributed_environment(parallel_config, rank,
|
|
||||||
distributed_init_method)
|
|
||||||
|
|
||||||
|
|
||||||
|
@ray.remote(num_gpus=1, max_calls=1)
|
||||||
def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
|
def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
|
||||||
distributed_init_port: str):
|
distributed_init_port: str):
|
||||||
init_test_distributed_environment(1, tensor_parallel_size, rank,
|
init_test_distributed_environment(1, tensor_parallel_size, rank,
|
||||||
@@ -43,6 +31,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
|
|||||||
assert torch.allclose(t, expected)
|
assert torch.allclose(t, expected)
|
||||||
|
|
||||||
|
|
||||||
|
@ray.remote(num_gpus=1, max_calls=1)
|
||||||
def all_gather_test_worker(tensor_parallel_size: int, rank: int,
|
def all_gather_test_worker(tensor_parallel_size: int, rank: int,
|
||||||
distributed_init_port: str):
|
distributed_init_port: str):
|
||||||
init_test_distributed_environment(1, tensor_parallel_size, rank,
|
init_test_distributed_environment(1, tensor_parallel_size, rank,
|
||||||
@@ -64,20 +53,40 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
|
|||||||
assert torch.allclose(t, expected)
|
assert torch.allclose(t, expected)
|
||||||
|
|
||||||
|
|
||||||
|
@ray.remote(num_gpus=1, max_calls=1)
|
||||||
|
def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
|
||||||
|
distributed_init_port: str):
|
||||||
|
init_test_distributed_environment(1, tensor_parallel_size, rank,
|
||||||
|
distributed_init_port)
|
||||||
|
test_dict = {
|
||||||
|
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
|
||||||
|
"b": torch.arange(16, dtype=torch.int8, device="cuda"),
|
||||||
|
"c": "test",
|
||||||
|
"d": [1, 2, 3],
|
||||||
|
"e": {
|
||||||
|
"a": 1,
|
||||||
|
"b": 2
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
broadcast_tensor_dict(test_dict, src=0)
|
||||||
|
else:
|
||||||
|
recv_dict = broadcast_tensor_dict(src=0)
|
||||||
|
assert len(recv_dict) == len(test_dict)
|
||||||
|
assert torch.allclose(recv_dict["a"], test_dict["a"])
|
||||||
|
assert torch.allclose(recv_dict["b"], test_dict["b"])
|
||||||
|
assert recv_dict["c"] == test_dict["c"]
|
||||||
|
assert recv_dict["d"] == test_dict["d"]
|
||||||
|
assert recv_dict["e"] == test_dict["e"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||||
reason="Need at least 2 GPUs to run the test.")
|
reason="Need at least 2 GPUs to run the test.")
|
||||||
@pytest.mark.parametrize("tensor_parallel_size", [2])
|
@pytest.mark.parametrize("tensor_parallel_size", [2])
|
||||||
@pytest.mark.parametrize("test_target",
|
@pytest.mark.parametrize("test_target", [
|
||||||
[all_reduce_test_worker, all_gather_test_worker])
|
all_reduce_test_worker, all_gather_test_worker,
|
||||||
|
broadcast_tensor_dict_test_worker
|
||||||
|
])
|
||||||
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
|
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
|
||||||
set_start_method("spawn", force=True)
|
multi_process_tensor_parallel(tensor_parallel_size, test_target)
|
||||||
distributed_init_port = get_open_port()
|
|
||||||
processes = []
|
|
||||||
for rank in range(tensor_parallel_size):
|
|
||||||
p = Process(target=test_target,
|
|
||||||
args=(tensor_parallel_size, rank, distributed_init_port))
|
|
||||||
p.start()
|
|
||||||
processes.append(p)
|
|
||||||
for p in processes:
|
|
||||||
p.join()
|
|
||||||
assert all(p.exitcode == 0 for p in processes)
|
|
||||||
|
|||||||
85	tests/distributed/test_custom_all_reduce.py	Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
import random
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import ray
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar
|
||||||
|
from vllm.model_executor.parallel_utils.communication_op import (
|
||||||
|
tensor_model_parallel_all_reduce)
|
||||||
|
from vllm.test_utils import (init_test_distributed_environment,
|
||||||
|
multi_process_tensor_parallel)
|
||||||
|
|
||||||
|
random.seed(42)
|
||||||
|
test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
|
||||||
|
for i, v in enumerate(test_sizes):
|
||||||
|
test_sizes[i] -= v % 8
|
||||||
|
|
||||||
|
|
||||||
|
@ray.remote(num_gpus=1, max_calls=1)
|
||||||
|
def graph_allreduce(world_size, rank, distributed_init_port):
|
||||||
|
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||||
|
device = torch.device(f"cuda:{rank}")
|
||||||
|
torch.cuda.set_device(device)
|
||||||
|
init_test_distributed_environment(1, world_size, rank,
|
||||||
|
distributed_init_port)
|
||||||
|
|
||||||
|
custom_ar.init_custom_ar()
|
||||||
|
for sz in test_sizes:
|
||||||
|
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||||
|
with custom_ar.capture():
|
||||||
|
# use integers so result matches NCCL exactly
|
||||||
|
inp1 = torch.randint(1,
|
||||||
|
16, (sz, ),
|
||||||
|
dtype=dtype,
|
||||||
|
device=torch.cuda.current_device())
|
||||||
|
inp2 = torch.randint(1,
|
||||||
|
16, (sz, ),
|
||||||
|
dtype=dtype,
|
||||||
|
device=torch.cuda.current_device())
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(graph):
|
||||||
|
out1 = tensor_model_parallel_all_reduce(inp1)
|
||||||
|
# the input buffer is immediately modified to test
|
||||||
|
# synchronization
|
||||||
|
dist.all_reduce(inp1)
|
||||||
|
out2 = tensor_model_parallel_all_reduce(inp2)
|
||||||
|
dist.all_reduce(inp2)
|
||||||
|
graph.replay()
|
||||||
|
assert torch.allclose(out1, inp1)
|
||||||
|
assert torch.allclose(out2, inp2)
|
||||||
|
|
||||||
|
|
||||||
|
@ray.remote(num_gpus=1, max_calls=1)
|
||||||
|
def eager_allreduce(world_size, rank, distributed_init_port):
|
||||||
|
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||||
|
device = torch.device(f"cuda:{rank}")
|
||||||
|
torch.cuda.set_device(device)
|
||||||
|
init_test_distributed_environment(1, world_size, rank,
|
||||||
|
distributed_init_port)
|
||||||
|
|
||||||
|
sz = 1024
|
||||||
|
custom_ar.init_custom_ar()
|
||||||
|
fa = custom_ar.get_handle()
|
||||||
|
inp = torch.ones(sz, dtype=torch.float32, device=device)
|
||||||
|
out = fa.all_reduce_unreg(inp)
|
||||||
|
assert torch.allclose(out, inp * world_size)
|
||||||
|
|
||||||
|
inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
|
||||||
|
out = fa.all_reduce_unreg(inp)
|
||||||
|
assert torch.allclose(out, inp * world_size)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||||
|
reason="Need at least 2 GPUs to run the test.")
|
||||||
|
@pytest.mark.parametrize("tensor_parallel_size", [2])
|
||||||
|
@pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce])
|
||||||
|
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
|
||||||
|
multi_process_tensor_parallel(tensor_parallel_size, test_target)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
multi_process_tensor_parallel(2, graph_allreduce)
|
||||||
254	tests/entrypoints/test_openai_server.py	Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
import ray # using Ray for overall ease of process management, parallel requests, and debugging.
|
||||||
|
import openai # use the official client for correctness check
|
||||||
|
|
||||||
|
MAX_SERVER_START_WAIT_S = 600  # wait for the server to start for up to 600 seconds
|
||||||
|
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
@ray.remote(num_gpus=1)
|
||||||
|
class ServerRunner:
|
||||||
|
|
||||||
|
def __init__(self, args):
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["PYTHONUNBUFFERED"] = "1"
|
||||||
|
self.proc = subprocess.Popen(
|
||||||
|
["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
|
||||||
|
env=env,
|
||||||
|
stdout=sys.stdout,
|
||||||
|
stderr=sys.stderr,
|
||||||
|
)
|
||||||
|
self._wait_for_server()
|
||||||
|
|
||||||
|
def ready(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _wait_for_server(self):
|
||||||
|
# run health check
|
||||||
|
start = time.time()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
if requests.get(
|
||||||
|
"http://localhost:8000/health").status_code == 200:
|
||||||
|
break
|
||||||
|
except Exception as err:
|
||||||
|
if self.proc.poll() is not None:
|
||||||
|
raise RuntimeError("Server exited unexpectedly.") from err
|
||||||
|
|
||||||
|
time.sleep(0.5)
|
||||||
|
if time.time() - start > MAX_SERVER_START_WAIT_S:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Server failed to start in time.") from err
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if hasattr(self, "proc"):
|
||||||
|
self.proc.terminate()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def server():
|
||||||
|
ray.init()
|
||||||
|
server_runner = ServerRunner.remote([
|
||||||
|
"--model",
|
||||||
|
MODEL_NAME,
|
||||||
|
"--dtype",
|
||||||
|
"bfloat16", # use half precision for speed and memory savings in CI environment
|
||||||
|
"--max-model-len",
|
||||||
|
"8192",
|
||||||
|
"--enforce-eager",
|
||||||
|
])
|
||||||
|
ray.get(server_runner.ready.remote())
|
||||||
|
yield server_runner
|
||||||
|
ray.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def client():
|
||||||
|
client = openai.AsyncOpenAI(
|
||||||
|
base_url="http://localhost:8000/v1",
|
||||||
|
api_key="token-abc123",
|
||||||
|
)
|
||||||
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
async def test_single_completion(server, client: openai.AsyncOpenAI):
|
||||||
|
completion = await client.completions.create(model=MODEL_NAME,
|
||||||
|
prompt="Hello, my name is",
|
||||||
|
max_tokens=5,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
assert completion.id is not None
|
||||||
|
assert completion.choices is not None and len(completion.choices) == 1
|
||||||
|
assert completion.choices[0].text is not None and len(
|
||||||
|
completion.choices[0].text) >= 5
|
||||||
|
assert completion.choices[0].finish_reason == "length"
|
||||||
|
assert completion.usage == openai.types.CompletionUsage(
|
||||||
|
completion_tokens=5, prompt_tokens=6, total_tokens=11)
|
||||||
|
|
||||||
|
# test using token IDs
|
||||||
|
completion = await client.completions.create(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
prompt=[0, 0, 0, 0, 0],
|
||||||
|
max_tokens=5,
|
||||||
|
temperature=0.0,
|
||||||
|
)
|
||||||
|
assert completion.choices[0].text is not None and len(
|
||||||
|
completion.choices[0].text) >= 5
|
||||||
|
|
||||||
|
|
||||||
|
async def test_single_chat_session(server, client: openai.AsyncOpenAI):
|
||||||
|
messages = [{
|
||||||
|
"role": "system",
|
||||||
|
"content": "you are a helpful assistant"
|
||||||
|
}, {
|
||||||
|
"role": "user",
|
||||||
|
"content": "what is 1+1?"
|
||||||
|
}]
|
||||||
|
|
||||||
|
# test single completion
|
||||||
|
chat_completion = await client.chat.completions.create(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=10,
|
||||||
|
)
|
||||||
|
assert chat_completion.id is not None
|
||||||
|
assert chat_completion.choices is not None and len(
|
||||||
|
chat_completion.choices) == 1
|
||||||
|
assert chat_completion.choices[0].message is not None
|
||||||
|
message = chat_completion.choices[0].message
|
||||||
|
assert message.content is not None and len(message.content) >= 10
|
||||||
|
assert message.role == "assistant"
|
||||||
|
messages.append({"role": "assistant", "content": message.content})
|
||||||
|
|
||||||
|
# test multi-turn dialogue
|
||||||
|
messages.append({"role": "user", "content": "express your result in json"})
|
||||||
|
chat_completion = await client.chat.completions.create(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=10,
|
||||||
|
)
|
||||||
|
message = chat_completion.choices[0].message
|
||||||
|
assert message.content is not None and len(message.content) >= 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_completion_streaming(server, client: openai.AsyncOpenAI):
|
||||||
|
prompt = "What is an LLM?"
|
||||||
|
|
||||||
|
single_completion = await client.completions.create(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
prompt=prompt,
|
||||||
|
max_tokens=5,
|
||||||
|
temperature=0.0,
|
||||||
|
)
|
||||||
|
single_output = single_completion.choices[0].text
|
||||||
|
single_usage = single_completion.usage
|
||||||
|
|
||||||
|
stream = await client.completions.create(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
prompt=prompt,
|
||||||
|
max_tokens=5,
|
||||||
|
temperature=0.0,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
chunks = []
|
||||||
|
async for chunk in stream:
|
||||||
|
chunks.append(chunk.choices[0].text)
|
||||||
|
assert chunk.choices[0].finish_reason == "length"
|
||||||
|
assert chunk.usage == single_usage
|
||||||
|
assert "".join(chunks) == single_output
|
||||||
|
|
||||||
|
|
||||||
|
async def test_chat_streaming(server, client: openai.AsyncOpenAI):
|
||||||
|
messages = [{
|
||||||
|
"role": "system",
|
||||||
|
"content": "you are a helpful assistant"
|
||||||
|
}, {
|
||||||
|
"role": "user",
|
||||||
|
"content": "what is 1+1?"
|
||||||
|
}]
|
||||||
|
|
||||||
|
# test single completion
|
||||||
|
chat_completion = await client.chat.completions.create(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=10,
|
||||||
|
temperature=0.0,
|
||||||
|
)
|
||||||
|
output = chat_completion.choices[0].message.content
|
||||||
|
stop_reason = chat_completion.choices[0].finish_reason
|
||||||
|
|
||||||
|
# test streaming
|
||||||
|
stream = await client.chat.completions.create(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=10,
|
||||||
|
temperature=0.0,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
chunks = []
|
||||||
|
async for chunk in stream:
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
if delta.role:
|
||||||
|
assert delta.role == "assistant"
|
||||||
|
if delta.content:
|
||||||
|
chunks.append(delta.content)
|
||||||
|
assert chunk.choices[0].finish_reason == stop_reason
|
||||||
|
assert "".join(chunks) == output
|
||||||
|
|
||||||
|
|
||||||
|
async def test_batch_completions(server, client: openai.AsyncOpenAI):
|
||||||
|
# test simple list
|
||||||
|
batch = await client.completions.create(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
prompt=["Hello, my name is", "Hello, my name is"],
|
||||||
|
max_tokens=5,
|
||||||
|
temperature=0.0,
|
||||||
|
)
|
||||||
|
assert len(batch.choices) == 2
|
||||||
|
assert batch.choices[0].text == batch.choices[1].text
|
||||||
|
|
||||||
|
# test n = 2
|
||||||
|
batch = await client.completions.create(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
prompt=["Hello, my name is", "Hello, my name is"],
|
||||||
|
n=2,
|
||||||
|
max_tokens=5,
|
||||||
|
temperature=0.0,
|
||||||
|
extra_body=dict(
|
||||||
|
# NOTE: this has to be true for n > 1 in vLLM, but not necessary for official client.
|
||||||
|
use_beam_search=True),
|
||||||
|
)
|
||||||
|
assert len(batch.choices) == 4
|
||||||
|
assert batch.choices[0].text != batch.choices[
|
||||||
|
1].text, "beam search should be different"
|
||||||
|
assert batch.choices[0].text == batch.choices[
|
||||||
|
2].text, "two copies of the same prompt should be the same"
|
||||||
|
assert batch.choices[1].text == batch.choices[
|
||||||
|
3].text, "two copies of the same prompt should be the same"
|
||||||
|
|
||||||
|
# test streaming
|
||||||
|
batch = await client.completions.create(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
prompt=["Hello, my name is", "Hello, my name is"],
|
||||||
|
max_tokens=5,
|
||||||
|
temperature=0.0,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
texts = [""] * 2
|
||||||
|
async for chunk in batch:
|
||||||
|
assert len(chunk.choices) == 1
|
||||||
|
choice = chunk.choices[0]
|
||||||
|
texts[choice.index] += choice.text
|
||||||
|
assert texts[0] == texts[1]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__])
|
||||||
@@ -1,44 +1,7 @@
-from typing import List, Tuple
-
 import pytest
-import torch
-
-
-def create_kv_caches(
-    num_blocks: int,
-    block_size: int,
-    num_layers: int,
-    num_heads: int,
-    head_size: int,
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-    torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-
-    scale = head_size**-0.5
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
-    key_caches = []
-    for _ in range(num_layers):
-        key_cache = torch.empty(size=key_cache_shape,
-                                dtype=dtype,
-                                device=device)
-        key_cache.uniform_(-scale, scale)
-        key_caches.append(key_cache)
-
-    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_caches = []
-    for _ in range(num_layers):
-        value_cache = torch.empty(size=value_cache_shape,
-                                  dtype=dtype,
-                                  device=device)
-        value_cache.uniform_(-scale, scale)
-        value_caches.append(value_cache)
-    return key_caches, value_caches
+from vllm.utils import create_kv_caches_with_random
 
 
 @pytest.fixture()
 def kv_cache_factory():
-    return create_kv_caches
+    return create_kv_caches_with_random
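The hunks below thread a new `kv_cache_dtype` argument through every call to the `kv_cache_factory` fixture, which now simply returns the shared `create_kv_caches_with_random` helper. As orientation only, here is a minimal sketch of how a test obtains caches from it; the argument order is an assumption inferred from the updated call sites in these hunks, not taken from the helper's definition:

```python
import torch
from vllm.utils import create_kv_caches_with_random

# Assumed argument order, mirroring the updated call sites:
# (num_blocks, block_size, num_layers, num_heads, head_size,
#  kv_cache_dtype, dtype, seed, device)
key_caches, value_caches = create_kv_caches_with_random(
    128, 16, 1, 8, 64, "auto", torch.half, 0, "cuda:0")
key_cache, value_cache = key_caches[0], value_caches[0]
```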
@@ -6,14 +6,16 @@ import torch
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
-from vllm._C import ops
+from vllm._C import ops, cache_ops
 from vllm.utils import get_max_shared_memory_bytes
 
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
 MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
-NUM_BLOCKS = 40000  # Arbitrary values for testing
+# There may not be enough gpu memory due to large NUM_BLOCKS.
+# Reduce NUM_BLOCKS when it happens.
+NUM_BLOCKS = 4321  # Arbitrary values for testing
 PARTITION_SIZE = 512
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -23,6 +25,7 @@ NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
+KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
 SEEDS = [0]
 DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
@@ -105,6 +108,7 @@ def ref_single_query_cached_kv_attention(
 @pytest.mark.parametrize("use_alibi", USE_ALIBI)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", DEVICES)
 def test_paged_attention(
@@ -116,6 +120,7 @@ def test_paged_attention(
     use_alibi: bool,
     block_size: int,
     dtype: torch.dtype,
+    kv_cache_dtype: str,
     seed: int,
     device: int,
 ) -> None:
@@ -158,8 +163,9 @@ def test_paged_attention(
 
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
-                                                num_kv_heads, head_size, dtype,
-                                                seed, gpu_id)
+                                                num_kv_heads, head_size,
+                                                kv_cache_dtype, dtype, seed,
+                                                gpu_id)
     key_cache, value_cache = key_caches[0], value_caches[0]
 
     # Call the paged attention kernel.
@@ -177,6 +183,7 @@ def test_paged_attention(
             block_size,
             max_context_len,
             alibi_slopes,
+            kv_cache_dtype,
         )
     elif version == "v2":
         num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
@@ -209,11 +216,30 @@ def test_paged_attention(
             block_size,
             max_context_len,
             alibi_slopes,
+            kv_cache_dtype,
         )
     else:
         raise AssertionError(f"Unknown version: {version}")
 
     # Run the reference implementation.
+    if kv_cache_dtype == "fp8_e5m2":
+        # Convert cache data back to dtype.
+        x = 16 // torch.tensor([], dtype=dtype).element_size()
+        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
+                           block_size, x)
+        dequantized_key_cache = torch.empty(size=key_cache_shape,
+                                            dtype=dtype,
+                                            device=gpu_id)
+        cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache)
+        key_cache = dequantized_key_cache
+
+        value_cache_shape = value_cache.shape
+        dequantized_value_cache = torch.empty(size=value_cache_shape,
+                                              dtype=dtype,
+                                              device=gpu_id)
+        cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache)
+        value_cache = dequantized_value_cache
+
     ref_output = torch.empty_like(query)
     ref_single_query_cached_kv_attention(
         ref_output,
@@ -230,7 +256,12 @@ def test_paged_attention(
     # NOTE(woosuk): Due to the kernel-level differences in the two
     # implementations, there is a small numerical difference in the two
     # outputs. Thus, we use a relaxed tolerance for the test.
-    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
+    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
+    # so we use a relaxed tolerance for the test.
+    atol, rtol = 1e-3, 1e-5
+    if kv_cache_dtype == "fp8_e5m2":
+        atol, rtol = 1e-2, 1e-5
+    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
 
 
 def ref_multi_query_kv_attention(
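For orientation, the `x = 16 // element_size` factor in the dequantization branch above packs the last key-cache dimension into 16-byte groups, so the reconstructed cache shape depends on the dtype being tested. A small worked example in plain PyTorch, independent of the kernels under test (the 4321/64/16 values simply mirror the test constants):

```python
import torch

for dtype in (torch.half, torch.bfloat16, torch.float):
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    # fp16/bf16 -> x = 8, fp32 -> x = 4
    print(dtype, x, "key cache shape:", (4321, 8, 64 // x, 16, x))
```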
@@ -3,18 +3,22 @@ import random
 import pytest
 import torch
 
+from typing import Tuple
+
 from vllm._C import cache_ops
 
+COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
-NUM_TOKENS = [83]  # Arbitrary values for testing
+NUM_TOKENS = [42]  # Arbitrary values for testing
 NUM_LAYERS = [1]  # Arbitrary values for testing
 NUM_HEADS = [8]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 BLOCK_SIZES = [8, 16, 32]
-NUM_BLOCKS = [1024, 36000]  # Arbitrary values for testing
+NUM_BLOCKS = [1024, 3600]  # Arbitrary values for testing
 NUM_MAPPINGS = [256]  # Arbitrary values for testing
 SEEDS = [0]
 DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
+KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
 
 
 @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@@ -26,6 +30,7 @@ DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
 @torch.inference_mode()
 def test_copy_blocks(
     kv_cache_factory,
@@ -38,6 +43,7 @@ def test_copy_blocks(
     dtype: torch.dtype,
     seed: int,
     device: int,
+    kv_cache_dtype: str,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
@@ -59,7 +65,8 @@ def test_copy_blocks(
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                 num_layers, num_heads,
-                                                head_size, dtype, seed, gpu_id)
+                                                head_size, kv_cache_dtype,
+                                                dtype, seed, gpu_id)
 
     # Clone the KV caches.
     cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
@@ -124,7 +131,7 @@ def test_reshape_and_cache(
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                 num_heads, head_size, dtype,
-                                                seed, gpu_id)
+                                                None, seed, gpu_id)
     key_cache, value_cache = key_caches[0], value_caches[0]
 
     # Clone the KV caches.
@@ -133,7 +140,7 @@ def test_reshape_and_cache(
 
     # Call the reshape_and_cache kernel.
     cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
-                                slot_mapping)
+                                slot_mapping, "auto")
 
     # Run the reference implementation.
     reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
@@ -149,3 +156,68 @@ def test_reshape_and_cache(
 
     assert torch.allclose(key_cache, cloned_key_cache)
     assert torch.allclose(value_cache, cloned_value_cache)
+
+
+@pytest.mark.parametrize("direction", COPYING_DIRECTION)
+@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_swap_blocks(
+    kv_cache_factory,
+    direction: Tuple[str, str],
+    num_mappings: int,
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: int,
+) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    src_device = f"{direction[0]}:{device}" if direction[
+        0] == "cuda" else direction[0]
+    dst_device = f"{direction[1]}:{device}" if direction[
+        1] == "cuda" else direction[1]
+
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    # For the same device, mapping must not overlap
+    if src_device == dst_device:
+        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
+        dst_blocks = random.sample(remaining_blocks, num_mappings)
+    else:
+        dst_blocks = random.sample(range(num_blocks), num_mappings)
+
+    block_mapping = dict(zip(src_blocks, dst_blocks))
+
+    # Create the KV caches on the first device.
+    src_key_caches, src_value_caches = kv_cache_factory(
+        num_blocks, block_size, 1, num_heads, head_size, dtype, seed,
+        src_device)
+
+    # Create the KV caches on the second device.
+    dist_key_caches, dist_value_caches = kv_cache_factory(
+        num_blocks, block_size, 1, num_heads, head_size, dtype, seed,
+        dst_device)
+
+    src_key_caches_clone = src_key_caches[0].clone()
+    src_value_caches_clone = src_value_caches[0].clone()
+
+    # Call the swap_blocks kernel.
+    cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
+    cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
+                          block_mapping)
+
+    for src, dst in block_mapping.items():
+        assert torch.allclose(src_key_caches_clone[src].cpu(),
+                              dist_key_caches[0][dst].cpu())
+        assert torch.allclose(src_value_caches_clone[src].cpu(),
+                              dist_value_caches[0][dst].cpu())
50 tests/kernels/test_fused_moe.py Normal file
@@ -0,0 +1,50 @@
import pytest
import torch

from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.activation import SiluAndMul


def torch_moe(a, w1, w2, topk_weight, topk_ids):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)
    out = torch.zeros(B * topk_ids.shape[1],
                      w2.shape[1],
                      dtype=a.dtype,
                      device=a.device)
    topk_ids = topk_ids.view(-1)
    topk_weight = topk_weight.view(-1)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(
                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1)).sum(dim=1)


@pytest.mark.parametrize("m", [512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
):
    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10

    score = torch.randn((m, e), device='cuda', dtype=dtype)
    score = torch.softmax(score, dim=-1)
    topk_weight, topk_ids = torch.topk(score, topk)

    triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
    torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)
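The reference `torch_moe` above spells the routing out: each token's softmaxed router scores are reduced to its top-k experts, and the per-expert outputs are recombined with the routing weights. A tiny standalone sketch of just that routing step, in plain PyTorch with made-up sizes:

```python
import torch

m, e, topk = 4, 8, 2          # tokens, experts, experts chosen per token
score = torch.softmax(torch.randn(m, e), dim=-1)
topk_weight, topk_ids = torch.topk(score, topk)
# topk_weight: (m, topk) routing weights, topk_ids: (m, topk) expert indices
assert topk_weight.shape == (m, topk) and topk_ids.shape == (m, topk)
```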
169 tests/kernels/test_prefix_prefill.py Normal file
@@ -0,0 +1,169 @@
import random
import pytest
import time

import torch
from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
    context_attention_fwd)
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

NUM_HEADS = [12]
HEAD_SIZES = [128]
DTYPES = [torch.float16]


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_contexted_kv_attention(
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
) -> None:
    random.seed(0)
    torch.manual_seed(0)
    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    block_size = 32
    max_block_per_request = 64
    subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)]

    num_tokens = sum(subquery_lens)
    query = torch.empty(num_tokens,
                        num_heads,
                        head_size,
                        dtype=dtype,
                        device='cuda')
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens,
                         num_heads,
                         head_size,
                         dtype=dtype,
                         device='cuda')

    kv = torch.empty(sum(seq_lens),
                     2,
                     num_heads,
                     head_size,
                     dtype=dtype,
                     device='cuda')
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)

    k_cache = torch.zeros(cache_size,
                          block_size,
                          num_heads,
                          head_size,
                          dtype=dtype,
                          device='cuda')
    v_cache = torch.zeros(cache_size,
                          block_size,
                          num_heads,
                          head_size,
                          dtype=dtype,
                          device='cuda')
    k = torch.zeros(sum(subquery_lens),
                    num_heads,
                    head_size,
                    dtype=dtype,
                    device='cuda')
    v = torch.zeros(sum(subquery_lens),
                    num_heads,
                    head_size,
                    dtype=dtype,
                    device='cuda')
    values = torch.arange(0, cache_size, dtype=torch.long, device='cuda')
    values = values[torch.randperm(cache_size)]
    block_table = values[:BS * max_block_per_request].view(
        BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.long, device='cuda')
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long, device='cuda')
    b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1],
                                            dtype=torch.long,
                                            device='cuda'),
                               dim=0)
    max_input_len = MAX_SEQ_LEN
    # copy kv to cache
    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
                                                dtype=torch.long,
                                                device='cuda'),
                                   dim=0)
    for i in range(BS):
        for j in range(subquery_lens[i]):
            k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
                                            j])
            v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
                                              b_ctx_len[i] + j])
        cur_ctx = 0
        block_id = 0
        while cur_ctx < b_ctx_len[i]:
            start_loc = b_seq_start_loc[i] + cur_ctx
            if cur_ctx + block_size > b_ctx_len[i]:
                end_loc = b_seq_start_loc[i] + b_ctx_len[i]
            else:
                end_loc = start_loc + block_size
            start_slot = block_table[i, block_id] * block_size
            end_slot = start_slot + end_loc - start_loc
            k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
                key[start_loc:end_loc])
            v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
                value[start_loc:end_loc])
            cur_ctx += block_size
            block_id += 1
    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8,
                           8).permute(0, 2, 3, 1, 4).contiguous()
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = v_cache.view(-1, block_size, num_heads,
                           head_size).permute(0, 2, 3, 1).contiguous()

    # Warm up the Triton kernel by calling it once before actually measuring generation time
    context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
                          b_start_loc, b_seq_len, b_ctx_len, max_input_len)
    torch.cuda.synchronize()
    start_time = time.time()
    context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
                          b_start_loc, b_seq_len, b_ctx_len, max_input_len)
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")

    scale = float(1.0 / (head_size**0.5))

    attn_op = xops.fmha.cutlass.FwOp()

    attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
        subquery_lens, seq_lens)
    output_ref = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
        op=attn_op,
    )
    torch.cuda.synchronize()
    start_time = time.time()
    output_ref = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
        op=attn_op,
    )
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
    output_ref = output_ref.squeeze(0)
    assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
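The two transposes near the end of this test re-pack the caches from the natural `[num_blocks, block_size, num_heads, head_size]` layout into the layouts the paged-attention path expects, as the in-code comments describe. A minimal shape-only sketch of the same permutes, in plain PyTorch with tiny made-up sizes:

```python
import torch

blocks, block_size, heads, head_size = 2, 32, 12, 128
cache = torch.zeros(blocks, block_size, heads, head_size)
k = cache.view(blocks, block_size, heads, head_size // 8,
               8).permute(0, 2, 3, 1, 4).contiguous()
v = cache.view(blocks, block_size, heads,
               head_size).permute(0, 2, 3, 1).contiguous()
assert k.shape == (blocks, heads, head_size // 8, block_size, 8)
assert v.shape == (blocks, heads, head_size, block_size)
```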
0 tests/lora/__init__.py Normal file
143 tests/lora/conftest.py Normal file
@@ -0,0 +1,143 @@
import contextlib
import gc
import tempfile
from collections import OrderedDict
from unittest.mock import patch, MagicMock

import pytest
import ray
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download

import vllm
from vllm.config import LoRAConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parallel_utils.parallel_state import (
    destroy_model_parallel, initialize_model_parallel)


def cleanup():
    destroy_model_parallel()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    torch.cuda.empty_cache()
    ray.shutdown()


@pytest.fixture(autouse=True)
def cleanup_fixture():
    yield
    cleanup()


@pytest.fixture
def dist_init():
    if not torch.distributed.is_initialized():
        temp_file = tempfile.mkstemp()[1]
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=1,
            rank=0,
            init_method=f"file://{temp_file}",
        )
        torch.distributed.all_reduce(torch.zeros(1).cuda())
    initialize_model_parallel(1, 1)
    yield
    cleanup()


@pytest.fixture
def dist_init_torch_only():
    if torch.distributed.is_initialized():
        return
    temp_file = tempfile.mkstemp()[1]
    torch.distributed.init_process_group(
        backend="nccl",
        world_size=1,
        rank=0,
        init_method=f"file://{temp_file}",
    )


@pytest.fixture
def dummy_model() -> nn.Module:
    model = nn.Sequential(
        OrderedDict([
            ("dense1", ColumnParallelLinear(764, 100)),
            ("dense2", RowParallelLinear(100, 50)),
            (
                "layer1",
                nn.Sequential(
                    OrderedDict([
                        ("dense1", ColumnParallelLinear(100, 10)),
                        ("dense2", RowParallelLinear(10, 50)),
                    ])),
            ),
            ("act2", nn.ReLU()),
            ("output", ColumnParallelLinear(50, 10)),
            ("outact", nn.Sigmoid()),
            # Special handling for lm_head & sampler
            ("lm_head", ParallelLMHead(512, 10)),
            ("sampler", Sampler(512))
        ]))
    model.config = MagicMock()
    return model


@pytest.fixture
def dummy_model_gate_up() -> nn.Module:
    model = nn.Sequential(
        OrderedDict([
            ("dense1", ColumnParallelLinear(764, 100)),
            ("dense2", RowParallelLinear(100, 50)),
            (
                "layer1",
                nn.Sequential(
                    OrderedDict([
                        ("dense1", ColumnParallelLinear(100, 10)),
                        ("dense2", RowParallelLinear(10, 50)),
                    ])),
            ),
            ("act2", nn.ReLU()),
            ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])),
            ("outact", nn.Sigmoid()),
            # Special handling for lm_head & sampler
            ("lm_head", ParallelLMHead(512, 10)),
            ("sampler", Sampler(512))
        ]))
    model.config = MagicMock()
    return model


@pytest.fixture(scope="session")
def sql_lora_files():
    return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")


@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
    cleanup()
    get_model_old = get_model

    def get_model_patched(model_config, lora_config=None):
        return get_model_old(model_config,
                             LoRAConfig(max_loras=4, max_lora_rank=8))

    with patch("vllm.worker.model_runner.get_model", get_model_patched):
        engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
    yield engine.llm_engine
    del engine
    cleanup()


@pytest.fixture
def llama_2_7b_model_extra_embeddings(
        llama_2_7b_engine_extra_embeddings) -> nn.Module:
    yield llama_2_7b_engine_extra_embeddings.driver_worker.model_runner.model
709 tests/lora/test_layers.py Normal file
@@ -0,0 +1,709 @@
import pytest
import random
from copy import deepcopy
from dataclasses import dataclass
from typing import List, Optional, Dict, Tuple

import torch
import torch.nn.functional as F

from vllm.lora.layers import (
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLora,
    VocabParallelEmbeddingWithLoRA,
    RowParallelLinearWithLoRA,
    SamplerWithLoRA,
    LoRAMapping,
    BaseLayerWithLoRA,
)
from vllm.lora.models import LoRALayerWeights, convert_mapping, PackedLoRALayerWeights
from vllm.config import LoRAConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear,
                                               QKVParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
from vllm.model_executor.utils import set_random_seed

from .utils import DummyLoRAManager

TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.float32: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}


def get_random_id_to_index(num_loras: int,
                           num_slots: int,
                           log: bool = True) -> List[Optional[int]]:
    """Creates a random lora_id_to_index mapping.

    Args:
        num_loras: The number of active loras in the mapping.
        num_slots: The number of slots in the mapping. Must be larger
            than num_loras.
        log: Whether to log the output.
    """

    if num_loras > num_slots:
        raise ValueError(
            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
            "num_loras must be less than or equal to num_slots.")

    slots: List[Optional[int]] = [None] * num_slots
    random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist()
    for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
        slots[slot_idx] = lora_id

    if log:
        print(f"Created lora_id_to_index mapping: {slots}.")

    return slots


def populate_loras(
    id_to_index: List[Optional[int]],
    layer: BaseLayerWithLoRA,
    layer_weights: torch.Tensor,
    generate_embeddings_tensor: int = 0,
    repeats: int = 1,
) -> Tuple[Dict[int, LoRALayerWeights], Dict[int, List[LoRALayerWeights]]]:
    """This method populates the lora layers with lora weights.

    Args:
        id_to_index: a list of lora ids. The index of the lora id
            represents which memory slot the lora matrices are
            stored in. A None value indicates a free slot.
        layer: the LoRAlayer to populate.
        layer_weights: the PyTorch tensor containing the layer's
            weights.
        generate_embeddings_tensor: whether to generate an
            embeddings tensor for each LoRA.
        repeats: must only be set for column parallel packed
            layers. Indicates the number of loras to compose
            together to create a single lora layer.
    """

    # Dictionary that maps the lora ID to the
    # corresponding lora weights.
    lora_dict: Dict[int, LoRALayerWeights] = dict()

    # Dictionary that maps the lora ID to the
    # corresponding subloras. Only useful when
    # repeats > 1.
    sublora_dict: Dict[int, List[LoRALayerWeights]] = dict()

    for slot_idx, lora_id in enumerate(id_to_index):
        if lora_id is not None:
            subloras = []
            sublora_len = layer_weights.shape[0] // repeats
            for i in range(repeats):
                sublora = DummyLoRAManager().init_random_lora(
                    module_name=f"fake_{i}",
                    weight=layer_weights,
                    generate_embeddings_tensor=generate_embeddings_tensor,
                )
                sublora.lora_b = sublora.lora_b[:, (sublora_len *
                                                    i):(sublora_len * (i + 1))]
                sublora.optimize()
                subloras.append(sublora)

            lora = PackedLoRALayerWeights.pack(
                subloras) if repeats > 1 else subloras[0]

            layer.set_lora(
                slot_idx,
                lora_a=lora.lora_a,
                lora_b=lora.lora_b,
                embeddings_tensor=lora.embeddings_tensor,
            )

            lora_dict[lora_id] = lora
            sublora_dict[lora_id] = subloras

    return lora_dict, sublora_dict


def create_random_inputs(
    active_lora_ids: List[int],
    num_inputs: int,
    input_size: Tuple[int, ...],
    input_range: Tuple[float, float],
    input_type: torch.dtype = torch.int,
) -> Tuple[List[torch.Tensor], List[int], List[int]]:
    """Creates random inputs.

    Args:
        active_lora_ids: lora IDs of active lora weights.
        num_inputs: the number of inputs to create.
        input_size: the size of each individual input.
        input_range: the range of values to include in the input.
            input_range[0] <= possible input values < input_range[1]
        input_type: the type of values in the input.
    """

    low, high = input_range

    inputs, index_mapping, prompt_mapping = [], [], []
    for _ in range(num_inputs):
        if input_type == torch.int:
            inputs.append(
                torch.randint(low=int(low),
                              high=int(high),
                              size=input_size,
                              device="cuda"))
        else:
            inputs.append(
                torch.rand(size=input_size, dtype=input_type, device="cuda") *
                high + low)

        lora_id = random.choice(active_lora_ids)
        index_mapping += [lora_id] * input_size[0]
        prompt_mapping += [lora_id]

    return inputs, index_mapping, prompt_mapping

|
||||||
|
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||||
|
def test_embeddings(dist_init, num_loras) -> None:
|
||||||
|
|
||||||
|
max_loras = 8
|
||||||
|
lora_config = LoRAConfig(max_loras=max_loras,
|
||||||
|
max_lora_rank=8,
|
||||||
|
lora_dtype=torch.float16)
|
||||||
|
|
||||||
|
def create_random_embedding_layer():
|
||||||
|
embedding = VocabParallelEmbedding(512, 256)
|
||||||
|
embedding.weight.data = torch.rand_like(embedding.weight.data)
|
||||||
|
embedding.weight.data[512:, :] = 0
|
||||||
|
lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
|
||||||
|
lora_embedding.create_lora_weights(max_loras, lora_config)
|
||||||
|
|
||||||
|
return embedding, lora_embedding
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
set_random_seed(i)
|
||||||
|
|
||||||
|
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
||||||
|
embedding, lora_embedding = create_random_embedding_layer()
|
||||||
|
|
||||||
|
lora_dict, _ = populate_loras(
|
||||||
|
id_to_index,
|
||||||
|
layer=lora_embedding,
|
||||||
|
layer_weights=embedding.weight.T,
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs, index_mapping, prompt_mapping = create_random_inputs(
|
||||||
|
active_lora_ids=list(lora_dict.keys()),
|
||||||
|
num_inputs=num_loras * 3,
|
||||||
|
input_size=(200, ),
|
||||||
|
input_range=(1, 512),
|
||||||
|
)
|
||||||
|
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||||
|
|
||||||
|
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||||
|
512, lora_config.lora_extra_vocab_size)
|
||||||
|
lora_embedding.set_mapping(*mapping_info)
|
||||||
|
|
||||||
|
lora_result = lora_embedding(torch.cat(inputs))
|
||||||
|
|
||||||
|
expected_results = []
|
||||||
|
for input_, lora_id in zip(inputs, prompt_mapping):
|
||||||
|
lora = lora_dict[lora_id]
|
||||||
|
result = embedding(input_)
|
||||||
|
after_a = F.embedding(
|
||||||
|
input_,
|
||||||
|
lora.lora_a,
|
||||||
|
)
|
||||||
|
result += (after_a @ lora.lora_b)
|
||||||
|
expected_results.append(result)
|
||||||
|
expected_result = torch.cat(expected_results)
|
||||||
|
|
||||||
|
rtol, atol = TOLERANCES[lora_result.dtype]
|
||||||
|
assert torch.allclose(lora_result,
|
||||||
|
expected_result,
|
||||||
|
rtol=rtol,
|
||||||
|
atol=atol)
|
||||||
|
|
||||||
|
# Check that resetting the lora weights succeeds
|
||||||
|
|
||||||
|
for slot_idx in range(max_loras):
|
||||||
|
lora_embedding.reset_lora(slot_idx)
|
||||||
|
|
||||||
|
inputs, index_mapping, prompt_mapping = create_random_inputs(
|
||||||
|
active_lora_ids=[0],
|
||||||
|
num_inputs=num_loras * 3,
|
||||||
|
input_size=(200, ),
|
||||||
|
input_range=(1, 512),
|
||||||
|
)
|
||||||
|
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||||
|
|
||||||
|
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||||
|
512, lora_config.lora_extra_vocab_size)
|
||||||
|
lora_embedding.set_mapping(*mapping_info, )
|
||||||
|
|
||||||
|
lora_result = lora_embedding(torch.cat(inputs))
|
||||||
|
expected_result = embedding(torch.cat(inputs))
|
||||||
|
|
||||||
|
rtol, atol = TOLERANCES[lora_result.dtype]
|
||||||
|
assert torch.allclose(lora_result,
|
||||||
|
expected_result,
|
||||||
|
rtol=rtol,
|
||||||
|
atol=atol)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
# @pytest.mark.skip(reason="Fails when loras are in any slot other than the first.")
|
||||||
|
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||||
|
def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None:
|
||||||
|
|
||||||
|
max_loras = 8
|
||||||
|
lora_config = LoRAConfig(max_loras=max_loras,
|
||||||
|
max_lora_rank=8,
|
||||||
|
lora_dtype=torch.float16)
|
||||||
|
|
||||||
|
def create_random_embedding_layer():
|
||||||
|
embedding = VocabParallelEmbedding(512, 256)
|
||||||
|
embedding_data = torch.rand_like(embedding.weight.data)
|
||||||
|
embedding.weight.data = embedding_data
|
||||||
|
embedding.weight.data[512:, :] = 0
|
||||||
|
expanded_embedding = VocabParallelEmbedding(
|
||||||
|
512 + lora_config.lora_extra_vocab_size * max_loras,
|
||||||
|
256,
|
||||||
|
org_num_embeddings=512)
|
||||||
|
expanded_embedding.weight.data[:512, :] = embedding_data
|
||||||
|
# We need to deepcopy the embedding as it will be modifed
|
||||||
|
# in place
|
||||||
|
lora_embedding = VocabParallelEmbeddingWithLoRA(
|
||||||
|
deepcopy(expanded_embedding))
|
||||||
|
lora_embedding.create_lora_weights(max_loras, lora_config)
|
||||||
|
|
||||||
|
return expanded_embedding, lora_embedding
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
set_random_seed(i)
|
||||||
|
|
||||||
|
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
||||||
|
expanded_embedding, lora_embedding = create_random_embedding_layer()
|
||||||
|
lora_dict, _ = populate_loras(
|
||||||
|
id_to_index,
|
||||||
|
layer=lora_embedding,
|
||||||
|
layer_weights=torch.zeros(
|
||||||
|
(256, 512 + lora_config.lora_extra_vocab_size)),
|
||||||
|
generate_embeddings_tensor=256,
|
||||||
|
)
|
||||||
|
|
||||||
|
# All embeddings tensors have the same shape.
|
||||||
|
embeddings_tensors = [
|
||||||
|
lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys())
|
||||||
|
]
|
||||||
|
embeddings_tensor_len = embeddings_tensors[0].shape[0]
|
||||||
|
|
||||||
|
# Add empty embeddings_tensors for unoccupied lora slots.
|
||||||
|
for _ in range(max_loras - len(embeddings_tensors)):
|
||||||
|
embeddings_tensors.append(
|
||||||
|
torch.zeros(embeddings_tensors[0].shape, device="cuda"))
|
||||||
|
|
||||||
|
inputs, index_mapping, prompt_mapping = create_random_inputs(
|
||||||
|
active_lora_ids=list(lora_dict.keys()),
|
||||||
|
num_inputs=num_loras * 3,
|
||||||
|
input_size=(200, ),
|
||||||
|
input_range=(1, 512),
|
||||||
|
)
|
||||||
|
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||||
|
|
||||||
|
original_inputs = deepcopy(inputs)
|
||||||
|
|
||||||
|
# Force some of the inputs to be in the extended embeddings range
|
||||||
|
# to guarantee that their behavior is tested.
|
||||||
|
for input_, original_input_, lora_id in zip(inputs, original_inputs,
|
||||||
|
prompt_mapping):
|
||||||
|
embedding_id = lora_id - 1
|
||||||
|
input_[-1] = 512 + (embedding_id * embeddings_tensor_len)
|
||||||
|
original_input_[-1] = 512
|
||||||
|
input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1)
|
||||||
|
original_input_[-2] = 512 + embeddings_tensor_len - 1
|
||||||
|
|
||||||
|
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||||
|
512, lora_config.lora_extra_vocab_size)
|
||||||
|
lora_embedding.set_mapping(*mapping_info, )
|
||||||
|
|
||||||
|
expanded_embedding.weight[512:512 +
|
||||||
|
(embeddings_tensor_len *
|
||||||
|
max_loras)] = torch.cat(embeddings_tensors)
|
||||||
|
|
||||||
|
lora_result = lora_embedding(torch.cat(original_inputs))
|
||||||
|
|
||||||
|
expected_results = []
|
||||||
|
for input_, original_input_, lora_id in zip(inputs, original_inputs,
|
||||||
|
prompt_mapping):
|
||||||
|
lora = lora_dict[lora_id]
|
||||||
|
result = expanded_embedding(input_)
|
||||||
|
after_a = F.embedding(
|
||||||
|
original_input_,
|
||||||
|
lora.lora_a,
|
||||||
|
)
|
||||||
|
result += (after_a @ lora.lora_b)
|
||||||
|
expected_results.append(result)
|
||||||
|
expected_result = torch.cat(expected_results)
|
||||||
|
|
||||||
|
rtol, atol = TOLERANCES[lora_result.dtype]
|
||||||
|
assert torch.allclose(lora_result,
|
||||||
|
expected_result,
|
||||||
|
rtol=rtol,
|
||||||
|
atol=atol)
|
||||||
|
|
||||||
|
# Check that resetting the lora weights succeeds
|
||||||
|
|
||||||
|
for slot_idx in range(max_loras):
|
||||||
|
lora_embedding.reset_lora(slot_idx)
|
||||||
|
|
||||||
|
inputs, index_mapping, prompt_mapping = create_random_inputs(
|
||||||
|
active_lora_ids=[0],
|
||||||
|
num_inputs=num_loras * 3,
|
||||||
|
input_size=(200, ),
|
||||||
|
input_range=(1, 512),
|
||||||
|
)
|
||||||
|
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||||
|
|
||||||
|
original_inputs = deepcopy(inputs)
|
||||||
|
|
||||||
|
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||||
|
512, lora_config.lora_extra_vocab_size)
|
||||||
|
lora_embedding.set_mapping(*mapping_info, )
|
||||||
|
|
||||||
|
lora_result = lora_embedding(torch.cat(original_inputs))
|
||||||
|
expected_result = expanded_embedding(torch.cat(inputs))
|
||||||
|
|
||||||
|
rtol, atol = TOLERANCES[lora_result.dtype]
|
||||||
|
assert torch.allclose(lora_result,
|
||||||
|
expected_result,
|
||||||
|
rtol=rtol,
|
||||||
|
atol=atol)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||||
|
def test_lm_head_sampler(dist_init, num_loras) -> None:
|
||||||
|
|
||||||
|
max_loras = 8
|
||||||
|
lora_config = LoRAConfig(max_loras=max_loras,
|
||||||
|
max_lora_rank=8,
|
||||||
|
lora_dtype=torch.float16)
|
||||||
|
|
||||||
|
def create_random_sampler_layer():
|
||||||
|
linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
|
||||||
|
1024, 32000)
|
||||||
|
linear.weight.data = torch.rand_like(linear.weight.data)
|
||||||
|
linear.weight.data[:, 32000:] = 0
|
||||||
|
sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000)
|
||||||
|
lora_sampler = SamplerWithLoRA(sampler, 1024, linear.weight.dtype,
|
||||||
|
linear.weight.device)
|
||||||
|
lora_sampler.create_lora_weights(max_loras, lora_config)
|
||||||
|
|
||||||
|
return linear, sampler, lora_sampler
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
set_random_seed(i)
|
||||||
|
|
||||||
|
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
||||||
|
linear, sampler, lora_sampler = create_random_sampler_layer()
|
||||||
|
|
||||||
|
# NOTE: all the generated loras share the same embeddings tensor.
|
||||||
|
lora_dict, _ = populate_loras(
|
||||||
|
id_to_index,
|
||||||
|
layer=lora_sampler,
|
||||||
|
layer_weights=linear.weight,
|
||||||
|
generate_embeddings_tensor=1024,
|
||||||
|
)
|
||||||
|
embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor
|
||||||
|
embeddings_tensor_len = embeddings_tensor.shape[0]
|
||||||
|
|
||||||
|
inputs, index_mapping, prompt_mapping = create_random_inputs(
|
||||||
|
active_lora_ids=list(lora_dict.keys()),
|
||||||
|
num_inputs=8 * num_loras, # * 3,
|
||||||
|
input_size=(1, 1024),
|
||||||
|
input_range=(0, 1),
|
||||||
|
input_type=torch.float32,
|
||||||
|
)
|
||||||
|
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||||
|
|
||||||
|
input_ = torch.rand(20, 1024, device="cuda")
|
||||||
|
mapping_info = convert_mapping(
|
||||||
|
lora_mapping,
|
||||||
|
id_to_index,
|
||||||
|
max_loras,
|
||||||
|
32000,
|
||||||
|
lora_config.lora_extra_vocab_size,
|
||||||
|
)
|
||||||
|
lora_sampler.set_mapping(*mapping_info, )
|
||||||
|
|
||||||
|
lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs),
|
||||||
|
embedding=linear.weight,
|
||||||
|
embedding_bias=None)
|
||||||
|
|
||||||
|
original_weight = linear.weight.clone()
|
||||||
|
|
||||||
|
linear.weight[sampler.org_vocab_size:sampler.org_vocab_size +
|
||||||
|
embeddings_tensor_len] = embeddings_tensor
|
||||||
|
|
||||||
|
sampler.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size
|
||||||
|
expected_results = []
|
||||||
|
for input_, lora_id in zip(inputs, prompt_mapping):
|
||||||
|
lora = lora_dict[lora_id]
|
||||||
|
result = sampler._get_logits(hidden_states=input_,
|
||||||
|
embedding=linear.weight,
|
||||||
|
embedding_bias=None)
|
||||||
|
result[:, 32000 + embeddings_tensor_len:] = float("-inf")
|
||||||
|
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
|
||||||
|
expected_results.append(result)
|
||||||
|
expected_result = torch.cat(expected_results)
|
||||||
|
sampler.org_vocab_size = 32000
|
||||||
|
|
||||||
|
# Check that resetting the lora weights succeeds
|
||||||
|
|
||||||
|
for slot_idx in range(max_loras):
|
||||||
|
lora_sampler.reset_lora(slot_idx)
|
||||||
|
|
||||||
|
inputs, index_mapping, prompt_mapping = create_random_inputs(
|
||||||
|
active_lora_ids=[0],
|
||||||
|
num_inputs=8 * num_loras * 3,
|
||||||
|
input_size=(1, 1024),
|
||||||
|
input_range=(0, 1),
|
||||||
|
input_type=torch.float32,
|
||||||
|
)
|
||||||
|
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||||
|
|
||||||
|
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||||
|
32000,
|
||||||
|
lora_config.lora_extra_vocab_size)
|
||||||
|
lora_sampler.set_mapping(*mapping_info, )
|
||||||
|
|
||||||
|
lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs),
|
||||||
|
embedding=original_weight,
|
||||||
|
embedding_bias=None)[:, :32000]
|
||||||
|
expected_result = sampler._get_logits(hidden_states=torch.cat(inputs),
|
||||||
|
embedding=original_weight,
|
||||||
|
embedding_bias=None)
|
||||||
|
|
||||||
|
rtol, atol = TOLERANCES[lora_result.dtype]
|
||||||
|
assert torch.allclose(lora_result,
|
||||||
|
expected_result,
|
||||||
|
rtol=rtol,
|
||||||
|
atol=atol)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||||
|
@pytest.mark.parametrize("orientation", ["row", "column"])
|
||||||
|
def test_linear_parallel(dist_init, num_loras, orientation) -> None:
|
||||||
|
|
||||||
|
max_loras = 8
|
||||||
|
lora_config = LoRAConfig(max_loras=max_loras,
|
||||||
|
max_lora_rank=8,
|
||||||
|
lora_dtype=torch.float16)
|
||||||
|
|
||||||
|
def create_random_linear_parallel_layer():
|
||||||
|
if orientation == "row":
|
||||||
|
linear = RowParallelLinear(4096, 4096, bias=False)
|
||||||
|
linear.weight.data = torch.rand_like(linear.weight.data)
|
||||||
|
lora_linear = RowParallelLinearWithLoRA(linear)
|
||||||
|
else:
|
||||||
|
linear = ColumnParallelLinear(4096, 4096, bias=False)
|
||||||
|
linear.weight.data = torch.rand_like(linear.weight.data)
|
||||||
|
lora_linear = ColumnParallelLinearWithLoRA(linear)
|
||||||
|
lora_linear.create_lora_weights(max_loras, lora_config)
|
||||||
|
|
||||||
|
return linear, lora_linear
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
set_random_seed(i)
|
||||||
|
|
||||||
|
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
||||||
|
linear, lora_linear = create_random_linear_parallel_layer()
|
||||||
|
|
||||||
|
lora_dict, _ = populate_loras(
|
||||||
|
id_to_index,
|
||||||
|
layer=lora_linear,
|
||||||
|
layer_weights=linear.weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs, index_mapping, prompt_mapping = create_random_inputs(
|
||||||
|
active_lora_ids=list(lora_dict.keys()),
|
||||||
|
num_inputs=32 * num_loras,
|
||||||
|
input_size=(1, 4096),
|
||||||
|
input_range=(0, 1),
|
||||||
|
input_type=torch.float32,
|
||||||
|
)
|
||||||
|
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||||
|
|
||||||
|
mapping_info = convert_mapping(
|
||||||
|
lora_mapping,
|
||||||
|
id_to_index,
|
||||||
|
max_loras,
|
||||||
|
512,
|
||||||
|
lora_config.lora_extra_vocab_size,
|
||||||
|
)
|
||||||
|
lora_linear.set_mapping(*mapping_info, )
|
||||||
|
|
||||||
|
lora_result = lora_linear(torch.cat(inputs))[0]
|
||||||
|
|
||||||
|
expected_results = []
|
||||||
|
for input_, lora_id in zip(inputs, prompt_mapping):
|
||||||
|
lora = lora_dict[lora_id]
|
||||||
|
result = linear(input_)[0]
|
||||||
|
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
|
||||||
|
expected_results.append(result)
|
||||||
|
expected_result = torch.cat(expected_results)
|
||||||
|
|
||||||
|
rtol, atol = TOLERANCES[lora_result.dtype]
|
||||||
|
assert torch.allclose(lora_result,
|
||||||
|
expected_result,
|
||||||
|
rtol=rtol,
|
||||||
|
atol=atol)
|
||||||
|
|
||||||
|
# Check that resetting the lora weights succeeds
|
||||||
|
|
||||||
|
for slot_idx in range(max_loras):
|
||||||
|
lora_linear.reset_lora(slot_idx)
|
||||||
|
|
||||||
|
inputs, index_mapping, prompt_mapping = create_random_inputs(
|
||||||
|
active_lora_ids=[0],
|
||||||
|
num_inputs=32 * num_loras,
|
||||||
|
input_size=(1, 4096),
|
||||||
|
input_range=(0, 1),
|
||||||
|
input_type=torch.float32,
|
||||||
|
)
|
||||||
|
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||||
|
|
||||||
|
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||||
|
512, lora_config.lora_extra_vocab_size)
|
||||||
|
lora_linear.set_mapping(*mapping_info, )
|
||||||
|
|
||||||
|
lora_result = lora_linear(torch.cat(inputs))[0]
|
||||||
|
expected_result = linear(torch.cat(inputs))[0]
|
||||||
|
|
||||||
|
rtol, atol = TOLERANCES[lora_result.dtype]
|
||||||
|
assert torch.allclose(lora_result,
|
||||||
|
expected_result,
|
||||||
|
rtol=rtol,
|
||||||
|
atol=atol)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||||
|
@pytest.mark.parametrize("repeats", [2, 3])
|
||||||
|
def test_column_parallel_packed(dist_init, num_loras, repeats) -> None:
|
||||||
|
|
||||||
|
max_loras = 8
|
||||||
|
lora_config = LoRAConfig(max_loras=max_loras,
|
||||||
|
max_lora_rank=8,
|
||||||
|
lora_dtype=torch.float16)
|
||||||
|
|
||||||
|
def create_column_parallel_packed_layer():
|
||||||
|
if repeats == 2:
|
||||||
|
linear = MergedColumnParallelLinear(4096, [4096] * repeats,
|
||||||
|
bias=False)
|
||||||
|
linear.weight.data = torch.rand_like(linear.weight.data)
|
||||||
|
lora_linear = MergedColumnParallelLinearWithLoRA(linear)
|
||||||
|
else:
|
||||||
|
linear = QKVParallelLinear(4096, 64, 32, bias=False)
|
||||||
|
linear.weight.data = torch.rand_like(linear.weight.data)
|
||||||
|
lora_linear = QKVParallelLinearWithLora(linear)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FakeConfig:
|
||||||
|
hidden_size = 4096
|
||||||
|
num_key_value_heads = 32
|
||||||
|
num_attention_heads = 32
|
||||||
|
|
||||||
|
lora_linear.create_lora_weights(max_loras,
|
||||||
|
lora_config,
|
||||||
|
model_config=FakeConfig())
|
||||||
|
|
||||||
|
return linear, lora_linear
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
set_random_seed(i)
|
||||||
|
|
||||||
|
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
||||||
|
|
||||||
|
linear, lora_linear = create_column_parallel_packed_layer()
|
||||||
|
|
||||||
|
lora_dict, sublora_dict = populate_loras(
|
||||||
|
id_to_index,
|
||||||
|
layer=lora_linear,
|
||||||
|
layer_weights=linear.weight,
|
||||||
|
repeats=repeats,
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs, index_mapping, prompt_mapping = create_random_inputs(
|
||||||
|
active_lora_ids=list(lora_dict.keys()),
|
||||||
|
num_inputs=32 * num_loras,
|
||||||
|
input_size=(1, 4096),
|
||||||
|
input_range=(0, 1),
|
||||||
|
input_type=torch.float32,
|
||||||
|
)
|
||||||
|
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||||
|
|
||||||
|
mapping_info = convert_mapping(
|
||||||
|
lora_mapping,
|
||||||
|
id_to_index,
|
||||||
|
max_loras,
|
||||||
|
512,
|
||||||
|
lora_config.lora_extra_vocab_size,
|
||||||
|
)
|
||||||
|
lora_linear.set_mapping(*mapping_info)
|
||||||
|
|
||||||
|
lora_result = lora_linear(torch.cat(inputs))[0]
|
||||||
|
|
||||||
|
expected_results = []
|
||||||
|
for input_, lora_id in zip(inputs, prompt_mapping):
|
||||||
|
result = linear(input_)[0]
|
||||||
|
subloras = sublora_dict[lora_id]
|
||||||
|
for i, sublora in enumerate(subloras):
|
||||||
|
result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * (
|
||||||
|
i + 1
|
||||||
|
)] += input_ @ sublora.lora_a @ sublora.lora_b * sublora.scaling
|
||||||
|
expected_results.append(result)
|
||||||
|
expected_result = torch.cat(expected_results)
|
||||||
|
|
||||||
|
rtol, atol = TOLERANCES[lora_result.dtype]
|
||||||
|
assert torch.allclose(lora_result,
|
||||||
|
expected_result,
|
||||||
|
rtol=rtol,
|
||||||
|
atol=atol)
|
||||||
|
|
||||||
|
for slot_idx in range(max_loras):
|
||||||
|
lora_linear.reset_lora(slot_idx)
|
||||||
|
|
||||||
|
inputs, index_mapping, prompt_mapping = create_random_inputs(
|
||||||
|
active_lora_ids=[0],
|
||||||
|
num_inputs=32 * num_loras,
|
||||||
|
input_size=(1, 4096),
|
||||||
|
input_range=(0, 1),
|
||||||
|
input_type=torch.float32,
|
||||||
|
)
|
||||||
|
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||||
|
|
||||||
|
mapping_info = convert_mapping(
|
||||||
|
lora_mapping,
|
||||||
|
id_to_index,
|
||||||
|
max_loras,
|
||||||
|
512,
|
||||||
|
lora_config.lora_extra_vocab_size,
|
||||||
|
)
|
||||||
|
lora_linear.set_mapping(*mapping_info)
|
||||||
|
|
||||||
|
lora_result = lora_linear(torch.cat(inputs))[0]
|
||||||
|
expected_result = linear(torch.cat(inputs))[0]
|
||||||
|
|
||||||
|
rtol, atol = TOLERANCES[lora_result.dtype]
|
||||||
|
assert torch.allclose(lora_result,
|
||||||
|
expected_result,
|
||||||
|
rtol=rtol,
|
||||||
|
atol=atol)
|
||||||
144 tests/lora/test_llama.py Normal file
@@ -0,0 +1,144 @@
import pytest
import ray

import vllm
from vllm.lora.request import LoRARequest
from .conftest import cleanup

MODEL_PATH = "meta-llama/Llama-2-7b-hf"


def do_sample(llm, lora_path: str, lora_id: int):
    prompts = [
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]",
|
||||||
|
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]",
|
||||||
|
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]"
|
||||||
|
]
|
||||||
|
sampling_params = vllm.SamplingParams(temperature=0,
|
||||||
|
max_tokens=256,
|
||||||
|
stop=["[/assistant]"])
|
||||||
|
outputs = llm.generate(
|
||||||
|
prompts,
|
||||||
|
sampling_params,
|
||||||
|
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
|
||||||
|
if lora_id else None)
|
||||||
|
# Print the outputs.
|
||||||
|
generated_texts = []
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
generated_texts.append(generated_text)
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
return generated_texts
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("tp_size", [1])
|
||||||
|
def test_llama_lora(sql_lora_files, tp_size):
|
||||||
|
# Cannot use as it will initialize torch.cuda too early...
|
||||||
|
# if torch.cuda.device_count() < tp_size:
|
||||||
|
# pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
|
||||||
|
|
||||||
|
llm = vllm.LLM(MODEL_PATH,
|
||||||
|
enable_lora=True,
|
||||||
|
max_num_seqs=16,
|
||||||
|
max_loras=4,
|
||||||
|
tensor_parallel_size=tp_size)
|
||||||
|
|
||||||
|
expected_no_lora_output = [
|
||||||
|
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]",
|
||||||
|
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ",
|
||||||
|
"\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m",
|
||||||
|
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ",
|
||||||
|
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ",
|
||||||
|
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE",
|
||||||
|
]
|
||||||
|
expected_lora_output = [
|
||||||
|
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ",
|
||||||
|
" SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ",
|
||||||
|
" SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ",
|
||||||
|
" SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ",
|
||||||
|
" SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ",
|
||||||
|
" SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' "
|
||||||
|
]
|
||||||
|
|
||||||
|
print("lora adapter created")
|
||||||
|
assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output
|
||||||
|
|
||||||
|
print("lora 1")
|
||||||
|
assert do_sample(llm, sql_lora_files, lora_id=1) == expected_lora_output
|
||||||
|
|
||||||
|
print("no lora")
|
||||||
|
assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output
|
||||||
|
|
||||||
|
print("lora 2")
|
||||||
|
assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output
|
||||||
|
|
||||||
|
print("removing lora")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("Requires multiple GPUs")
|
||||||
|
def test_llama_tensor_parallel_equality(sql_lora_files):
|
||||||
|
# Cannot use as it will initialize torch.cuda too early...
|
||||||
|
# if torch.cuda.device_count() < 4:
|
||||||
|
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
|
||||||
|
|
||||||
|
llm_tp1 = vllm.LLM(MODEL_PATH,
|
||||||
|
enable_lora=True,
|
||||||
|
max_num_seqs=16,
|
||||||
|
max_loras=4,
|
||||||
|
tensor_parallel_size=1)
|
||||||
|
output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1)
|
||||||
|
|
||||||
|
del llm_tp1
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
llm_tp2 = vllm.LLM(MODEL_PATH,
|
||||||
|
enable_lora=True,
|
||||||
|
max_num_seqs=16,
|
||||||
|
max_loras=4,
|
||||||
|
tensor_parallel_size=2)
|
||||||
|
output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1)
|
||||||
|
|
||||||
|
del llm_tp2
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
assert output_tp1 == output_tp2
|
||||||
|
|
||||||
|
llm_tp4 = vllm.LLM(MODEL_PATH,
|
||||||
|
enable_lora=True,
|
||||||
|
max_num_seqs=16,
|
||||||
|
max_loras=4,
|
||||||
|
tensor_parallel_size=4)
|
||||||
|
output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1)
|
||||||
|
|
||||||
|
del llm_tp4
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
assert output_tp1 == output_tp4
|
||||||
|
|
||||||
|
|
||||||
|
def test_llama_lora_warmup(sql_lora_files):
|
||||||
|
"""Test that the LLM initialization works with a warmup LORA path and is more conservative"""
|
||||||
|
|
||||||
|
@ray.remote(num_gpus=1)
|
||||||
|
def get_num_gpu_blocks_lora():
|
||||||
|
llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16)
|
||||||
|
num_gpu_blocks_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks
|
||||||
|
return num_gpu_blocks_lora_warmup
|
||||||
|
|
||||||
|
@ray.remote(num_gpus=1)
|
||||||
|
def get_num_gpu_blocks_no_lora():
|
||||||
|
llm = vllm.LLM(MODEL_PATH, max_num_seqs=16)
|
||||||
|
num_gpu_blocks_no_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks
|
||||||
|
return num_gpu_blocks_no_lora_warmup
|
||||||
|
|
||||||
|
num_gpu_blocks_lora_warmup = ray.get(get_num_gpu_blocks_lora.remote())
|
||||||
|
num_gpu_blocks_no_lora_warmup = ray.get(
|
||||||
|
get_num_gpu_blocks_no_lora.remote())
|
||||||
|
assert num_gpu_blocks_lora_warmup < num_gpu_blocks_no_lora_warmup, (
|
||||||
|
"The warmup with lora should be more"
|
||||||
|
" conservative than without lora, therefore the number of memory blocks for the KV cache should be "
|
||||||
|
"less when using lora than when not using lora")
|
||||||
tests/lora/test_lora.py (new file, 224 lines)
@@ -0,0 +1,224 @@
import pytest
import torch

from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice

from .utils import DummyLoRAManager

TENSOR_SIZES = [128, 1024, 2048, 4096, 8192, 11008, 11008 // 2, 11008 // 4]
QKV_TENSOR_SIZES = [
    (8192, 1024, 1024),
    (8192 // 8, 1024 // 8, 1024 // 8),
    (4096, 4096, 4096),
    (4096 // 2, 4096 // 2, 4096 // 2),
]
BATCH_SIZES = [8, 32, 256]
RANKS = [8]
DTYPES = [torch.float16]
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}


@pytest.mark.parametrize("m", TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora(m, n, k, rank, dtype) -> None:
    manager = DummyLoRAManager()

    module_name = "module"
    weight = torch.rand([m, n], device="cuda", dtype=dtype)

    manager.init_random_lora(module_name, weight, rank=rank)
    lora = manager.get_module_lora(module_name)

    input = torch.rand(k, n, device="cuda", dtype=dtype)
    expected = input @ lora.lora_a @ lora.lora_b * lora.scaling

    lora_a_stack = torch.zeros(8, 1, lora.lora_a.shape[1],
                               lora.lora_a.shape[0],
                               device="cuda", dtype=dtype)
    lora_b_stack = torch.zeros(8, 1, lora.lora_b.shape[1],
                               lora.lora_b.shape[0],
                               device="cuda", dtype=dtype)
    for i in range(lora_a_stack.shape[0]):
        lora_a_stack[i][0] = lora.lora_a.T
        lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T

    output = torch.zeros(k, m, device="cuda", dtype=dtype)
    _apply_lora(
        input, lora_a_stack, lora_b_stack,
        torch.randint(0, lora_a_stack.shape[0], (len(input), ), device="cuda"),
        output)

    rtol, atol = TOLERANCES[dtype]
    assert torch.allclose(expected, output, rtol=rtol, atol=atol)

    output[:] = 0
    _apply_lora(input, lora_a_stack, lora_b_stack,
                torch.full((len(input), ), -1, device="cuda"), output)
    assert torch.allclose(torch.zeros_like(output), output)

    manager.reset_lora()


@pytest.mark.parametrize("m", TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None:
    if m % 2 != 0:
        pytest.skip("m must be divisible by 2")
    if m // 2 not in TENSOR_SIZES:
        pytest.skip("m//2 must be in TENSOR_SIZES")

    manager = DummyLoRAManager()

    module_name = "module"
    weight = torch.rand([m // 2, n], device="cuda", dtype=dtype)

    manager.init_random_lora(module_name + "1", weight, rank=rank)
    lora_1 = manager.get_module_lora(module_name + "1")
    manager.init_random_lora(module_name + "2", weight, rank=rank)
    lora_2 = manager.get_module_lora(module_name + "2")

    input = torch.rand(k, n, device="cuda", dtype=dtype)
    expected = torch.cat([
        input @ lora_1.lora_a @ lora_1.lora_b * lora_1.scaling,
        input @ lora_2.lora_a @ lora_2.lora_b * lora_2.scaling
    ], dim=1)

    lora_a_stacks = [
        torch.zeros(8, 1, lora_1.lora_a.shape[1], lora_1.lora_a.shape[0],
                    device="cuda", dtype=dtype) for i in range(2)
    ]
    lora_b_stacks = [
        torch.zeros(8, 1, lora_1.lora_b.shape[1], lora_1.lora_b.shape[0],
                    device="cuda", dtype=dtype) for i in range(2)
    ]
    for i in range(lora_a_stacks[0].shape[0]):
        lora_a_stacks[0][i][0] = lora_1.lora_a.T
        lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T
        lora_a_stacks[1][i][0] = lora_2.lora_a.T
        lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T

    output = torch.zeros(k, m, device="cuda", dtype=dtype)
    _apply_lora_packed_nslice(
        input, lora_a_stacks, lora_b_stacks,
        torch.randint(0, lora_a_stacks[0].shape[0], (len(input), ),
                      device="cuda"), output, (m // 2, m // 2))

    rtol, atol = TOLERANCES[dtype]
    assert torch.allclose(expected, output, rtol=rtol, atol=atol)

    output[:] = 0
    _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
                              torch.full((len(input), ), -1, device="cuda"),
                              output, (m // 2, m // 2))
    assert torch.allclose(torch.zeros_like(output), output)

    manager.reset_lora()


@pytest.mark.parametrize("qkv", QKV_TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
    manager = DummyLoRAManager()

    module_name = "module"
    weight_q = torch.empty(qkv[0], n, device="cuda", dtype=dtype)
    weight_kv = torch.empty(qkv[1], n, device="cuda", dtype=dtype)

    manager.init_random_lora(module_name + "q", weight_q, rank=rank)
    lora_q = manager.get_module_lora(module_name + "q")
    manager.init_random_lora(module_name + "k", weight_kv, rank=rank)
    lora_k = manager.get_module_lora(module_name + "k")
    manager.init_random_lora(module_name + "v", weight_kv, rank=rank)
    lora_v = manager.get_module_lora(module_name + "v")

    input = torch.rand(k, n, device="cuda", dtype=dtype)
    expected = torch.cat([
        input @ lora_q.lora_a @ lora_q.lora_b * lora_q.scaling,
        input @ lora_k.lora_a @ lora_k.lora_b * lora_k.scaling,
        input @ lora_v.lora_a @ lora_v.lora_b * lora_v.scaling
    ], dim=1)

    lora_a_stacks = [
        torch.zeros(8, 1, lora_q.lora_a.shape[1], lora_q.lora_a.shape[0],
                    device="cuda", dtype=dtype)
    ] + [
        torch.zeros(8, 1, lora_k.lora_a.shape[1], lora_k.lora_a.shape[0],
                    device="cuda", dtype=dtype) for i in range(2)
    ]
    lora_b_stacks = [
        torch.zeros(8, 1, lora_q.lora_b.shape[1], lora_q.lora_b.shape[0],
                    device="cuda", dtype=dtype)
    ] + [
        torch.zeros(8, 1, lora_k.lora_b.shape[1], lora_k.lora_b.shape[0],
                    device="cuda", dtype=dtype) for i in range(2)
    ]
    for i in range(lora_a_stacks[0].shape[0]):
        lora_a_stacks[0][i][0] = lora_q.lora_a.T
        lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T
        lora_a_stacks[1][i][0] = lora_k.lora_a.T
        lora_b_stacks[1][i][0] = (lora_k.lora_b * lora_k.scaling).T
        lora_a_stacks[2][i][0] = lora_v.lora_a.T
        lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T

    output = torch.zeros(k, sum(qkv), device="cuda", dtype=dtype)
    _apply_lora_packed_nslice(
        input, lora_a_stacks, lora_b_stacks,
        torch.randint(0, lora_a_stacks[0].shape[0], (len(input), ),
                      device="cuda"), output, (qkv[0], qkv[1], qkv[2]))

    rtol, atol = TOLERANCES[dtype]
    assert torch.allclose(expected, output, rtol=rtol, atol=atol)

    output[:] = 0
    _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
                              torch.full((len(input), ), -1, device="cuda"),
                              output, (qkv[0], qkv[1], qkv[2]))
    assert torch.allclose(torch.zeros_like(output), output)

    manager.reset_lora()
tests/lora/test_lora_manager.py (new file, 475 lines)
@@ -0,0 +1,475 @@
import os
from typing import List

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (EMBEDDING_MODULES, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager, LoRAMapping)
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear


def test_from_lora_tensors(sql_lora_files):
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))
    lora_model = LoRAModel.from_lora_tensors(
        1, 8, 16, tensors, "cuda", embeddings=new_embeddings)
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[1] == 8
        embeddings_module = next(
            (k for k in EMBEDDING_MODULES if k in module_name), None)
        if embeddings_module:
            assert torch.equal(
                lora.embeddings_tensor,
                new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
                    device=lora.embeddings_tensor.device))
        else:
            assert lora.embeddings_tensor is None


def create_lora(lora_id: int, model: nn.Module,
                sub_modules: List[str]) -> LoRAModel:
    loras = {}
    for name in sub_modules:
        w = model.get_submodule(name).weight
        loras[name] = LoRALayerWeights(
            name,
            8,
            16,
            torch.rand([w.shape[1], 8], device="cuda"),
            torch.rand([8, w.shape[0]], device="cuda"),
        )
    return LoRAModel(lora_id, 8, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name,
    replaced_module_names,
    empty_replaced_module_name=None,
) -> LoRAModel:
    w = model.get_submodule(module_name).weight
    loras = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            continue
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name,
            8,
            16,
            torch.rand([w.shape[1], 8], device="cuda"),
            torch.rand([8, w.shape[0] // len(replaced_module_names)],
                       device="cuda"),
        )
    return LoRAModel(lora_id, 8, loras)


def test_replace_submodules(dist_init, dummy_model):
    model = dummy_model
    manager = LoRAModelManager(
        model, 1, 1, 1,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
        lora_target_modules=["dense1", "layer1.dense2"])
    model = manager.model

    assert isinstance(model.get_submodule("dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("dense2"), RowParallelLinear)
    assert isinstance(model.get_submodule("layer1.dense2"),
                      RowParallelLinearWithLoRA)


def test_lora_model_manager(dist_init, dummy_model):
    model = dummy_model
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    manager = LoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
        lora_target_modules=["dense1", "dense2", "lm_head"])
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_lora(model_lora1)
    assert manager.activate_lora(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_lora(model_lora1)
    assert not manager.activate_lora(1)
    assert manager.add_lora(model_lora2)
    assert manager.activate_lora(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_lora(model_lora2)
    assert not manager.activate_lora(2)
    assert manager.add_lora(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    with pytest.raises(ValueError):
        assert manager.activate_lora(3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert manager.remove_lora(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_lora(model_lora2.id)
    assert manager.remove_lora(model_lora1.id)
    assert not manager.remove_lora(model_lora1.id)
    assert manager.add_lora(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None
    assert manager.add_lora(model_lora2)
    assert manager.activate_lora(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_lora(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2


def test_lora_lru_cache_model_manager(dist_init, dummy_model):
    model = dummy_model
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    manager = LRUCacheLoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
        lora_target_modules=["dense1", "dense2", "lm_head"])
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_lora(model_lora1)
    assert manager.activate_lora(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_lora(model_lora1)
    assert not manager.activate_lora(1)
    assert manager.add_lora(model_lora2)
    assert manager.activate_lora(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_lora(model_lora2)
    assert not manager.activate_lora(2)
    assert manager.add_lora(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert manager.activate_lora(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2
    assert manager.remove_lora(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_lora(model_lora2.id)
    assert manager.remove_lora(model_lora1.id)
    assert not manager.remove_lora(model_lora1.id)
    assert manager.add_lora(model_lora1)
    assert manager.activate_lora(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    assert manager.add_lora(model_lora2)
    assert manager.deactivate_lora(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_lora(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_lora(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3


def test_lru_lora_model_manager(dist_init, dummy_model):
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"])
    manager = LRUCacheLoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2),
        ["dense1", "dense2", "lm_head"])

    assert all(x is None for x in manager.lora_index_to_id)

    # Add up to capacity
    assert manager.add_lora(model_lora1)
    assert manager.add_lora(model_lora2)
    assert manager.activate_lora(1)
    assert manager.activate_lora(2)

    assert set(manager.list_loras()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Add over capacity
    assert manager.add_lora(model_lora3)
    assert manager.add_lora(model_lora4)
    assert manager.activate_lora(3)
    assert manager.activate_lora(4)

    assert set(manager.list_loras()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Add 3 again to move it to the top and then add 2
    # should return false since it's in already
    assert not manager.add_lora(model_lora3)
    assert not manager.activate_lora(3)
    assert manager.add_lora(model_lora2)
    assert manager.activate_lora(2)

    assert set(manager.list_loras()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Remove manually
    assert manager.remove_lora(3)
    assert not manager.remove_lora(3)

    assert set(manager.list_loras()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2

    assert manager.add_lora(model_lora3)
    assert manager.activate_lora(3)
    assert manager.add_lora(model_lora4)
    assert manager.activate_lora(4)

    assert set(manager.list_loras()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_lora()
    assert set(manager.list_loras()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_lora()
    assert set(manager.list_loras()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    assert not manager.remove_oldest_lora()
    assert set(manager.list_loras()) == set()
    assert all(x is None for x in manager.lora_index_to_id)


def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
                                       sql_lora_files):
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_lora_manager = LRUCacheWorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config,
        torch.device("cuda"))
    worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_lora_manager.set_active_loras([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {1, 2}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2

    worker_lora_manager.set_active_loras([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {1, 2, 3, 4}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
    assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 3
    assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4

    worker_lora_manager.set_active_loras([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {1, 2, 4, 5}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
    assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5
    assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4

    worker_lora_manager.set_active_loras([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {1, 2, 4, 5}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
    assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5
    assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4

    worker_lora_manager.set_active_loras([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {1, 6, 7, 8}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 7
    assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 8
    assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 6

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_lora_manager.set_active_loras([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)


def test_worker_lora_manager(llama_2_7b_model_extra_embeddings,
                             sql_lora_files):
    # Should remove every LoRA not specified in the request.
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_lora_manager = WorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config,
        torch.device("cuda"))
    worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_lora_manager.set_active_loras([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {1, 2}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2

    worker_lora_manager.set_active_loras([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {1, 3, 4}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 3
    assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 4

    worker_lora_manager.set_active_loras([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {1, 2, 5}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
    assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5

    worker_lora_manager.set_active_loras([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {1}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] is None
    assert worker_lora_manager._lora_manager.lora_index_to_id[2] is None

    worker_lora_manager.set_active_loras([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {6, 7, 8}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 8
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 6
    assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 7

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_lora_manager.set_active_loras([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)


def test_packed_loras(dist_init, dummy_model_gate_up):
    model = dummy_model_gate_up
    model_lora = create_packed_lora(
        1,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"])
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2),
        ["gate_up_proj"])
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)
    assert manager.add_lora(model_lora)
    assert manager.add_lora(model_lora1)

    packed_lora = model_lora.get_lora("gate_up_proj")
    assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)

    assert torch.allclose(packed_lora.lora_a[0],
                          model_lora.get_lora("gate_proj").lora_a)
    assert torch.allclose(packed_lora.lora_b[0],
                          model_lora.get_lora("gate_proj").lora_b)
    assert torch.allclose(packed_lora.lora_a[1],
                          model_lora.get_lora("up_proj").lora_a)
    assert torch.allclose(packed_lora.lora_b[1],
                          model_lora.get_lora("up_proj").lora_b)

    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)

    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    assert torch.allclose(packed_lora1.lora_a[1],
                          model_lora1.get_lora("up_proj").lora_a)
    assert torch.allclose(packed_lora1.lora_b[1],
                          model_lora1.get_lora("up_proj").lora_b)
tests/lora/test_punica.py (new file, 175 lines)
@@ -0,0 +1,175 @@
# Based on code from https://github.com/punica-ai/punica

import pytest
import torch

import vllm.lora.punica as punica


def assert_close(a, b):
    rtol, atol = {
        torch.float16: (5e-3, 5e-3),
        torch.bfloat16: (3e-2, 2e-2),
        torch.float32: (None, None),
    }[a.dtype]
    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


def _lora_ref_impl(
    y_final: torch.Tensor,
    x: torch.Tensor,
    wa_T_all: torch.Tensor,
    wb_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
    scale: float,
):
    y_stage_1 = torch.empty(
        (x.size(0), wa_T_all.size(-2)),
        dtype=torch.float32,
        device=x.device,
    )
    bs = x.shape[0]
    s = torch.tensor(scale, dtype=torch.float32, device=x.device)
    for i, lora_idx in zip(range(bs), indicies.cpu().tolist()):
        xi = x[i].unsqueeze(0).to(torch.float32)
        wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)
        wb = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)

        tmp = xi @ wa
        y_stage_1[i] = tmp.squeeze(0)
        y_final[i] += (tmp @ wb).squeeze(0) * s
    return y_final, y_stage_1


H1 = H2 = [
    128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120,
    5504, 5632, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 32000,
    32256, 32512, 32768, 33024
]
SEED = [0xabcdabcd987]


@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("h2", H2)
@pytest.mark.parametrize("seed", SEED)
@torch.inference_mode()
def test_lora_correctness(dtype_str, h1, h2, seed):
    torch.manual_seed(seed)
    num_loras = 4
    num_layers = 1
    r = 8
    bs = 32
    scale = 0.123
    dtype = getattr(torch, dtype_str)
    device = torch.device("cuda")

    wa_T_all = torch.randn(num_loras, num_layers, r, h1,
                           dtype=dtype, device=device)
    wb_T_all = torch.randn(num_loras, num_layers, h2, r,
                           dtype=dtype, device=device)
    indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)

    for layer_idx in range(num_layers):
        x = torch.randn(bs, h1, dtype=dtype, device=device)
        y = torch.randn(bs, h2, dtype=dtype, device=device)

        y_ref = y.clone()
        _lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale)

        y_our = y.clone()
        punica.add_lora(y_our, x, wa_T_all, wb_T_all, indices, layer_idx,
                        scale)

        assert_close(y_ref, y_our)


@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("h2", H2)
@pytest.mark.parametrize("seed", SEED)
@torch.inference_mode()
def test_lora_correctness_slice(dtype_str, h1, h2, seed):
    if h2 % 3 != 0 or h2 // 3 not in H1:
        pytest.skip("h2 must be divisible by 3 and in supported shapes")
    torch.manual_seed(seed)
    num_loras = 4
    num_layers = 1
    r = 8
    bs = 32
    scale = 0.123
    dtype = getattr(torch, dtype_str)
    device = torch.device("cuda")

    wa_T_all_0 = torch.randn(num_loras, num_layers, r, h1,
                             dtype=dtype, device=device)
    wa_T_all_1 = torch.randn(num_loras, num_layers, r, h1,
                             dtype=dtype, device=device)
    wa_T_all_2 = torch.randn(num_loras, num_layers, r, h1,
                             dtype=dtype, device=device)
    wb_T_all_0 = torch.randn(num_loras, num_layers, h2 // 3, r,
                             dtype=dtype, device=device)
    wb_T_all_1 = torch.randn(num_loras, num_layers, h2 // 3, r,
                             dtype=dtype, device=device)
    wb_T_all_2 = torch.randn(num_loras, num_layers, h2 // 3, r,
                             dtype=dtype, device=device)

    indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)

    for layer_idx in range(num_layers):
        x = torch.randn(bs, h1, dtype=dtype, device=device)
        y = torch.randn(bs, h2, dtype=dtype, device=device)
        s = h2 // 3

        y_ref = y.clone()
        _lora_ref_impl(y_ref[:, :s], x, wa_T_all_0, wb_T_all_0, indices,
                       layer_idx, scale)
        _lora_ref_impl(y_ref[:, s:s * 2], x, wa_T_all_1, wb_T_all_1, indices,
                       layer_idx, scale)
        _lora_ref_impl(y_ref[:, s * 2:], x, wa_T_all_2, wb_T_all_2, indices,
                       layer_idx, scale)

        y_our = y.clone()
        punica.add_lora_slice(y_our, x, wa_T_all_0, wb_T_all_0, indices,
                              layer_idx, scale, 0, s)
        punica.add_lora_slice(y_our, x, wa_T_all_1, wb_T_all_1, indices,
                              layer_idx, scale, s, s)
        punica.add_lora_slice(y_our, x, wa_T_all_2, wb_T_all_2, indices,
                              layer_idx, scale, s * 2, s)

        assert_close(y_ref[:, :s], y_our[:, :s])
        assert_close(y_ref[:, s:s * 2], y_our[:, s:s * 2])
        assert_close(y_ref[:, s * 2:], y_our[:, s * 2:])
tests/lora/test_tokenizer.py (new file, 69 lines)
@@ -0,0 +1,69 @@
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import TokenizerGroup, get_lora_tokenizer


@pytest.mark.asyncio
async def test_transformers_tokenizer():
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer = TokenizerGroup(
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )
    assert reference_tokenizer.encode("prompt") == tokenizer.encode(
        request_id="request_id", prompt="prompt", lora_request=None)
    assert reference_tokenizer.encode(
        "prompt") == await tokenizer.encode_async(request_id="request_id",
                                                  prompt="prompt",
                                                  lora_request=None)
    assert isinstance(tokenizer.get_lora_tokenizer(None),
                      PreTrainedTokenizerBase)
    assert tokenizer.get_lora_tokenizer(
        None) == await tokenizer.get_lora_tokenizer_async(None)


@pytest.mark.asyncio
async def test_transformers_tokenizer_lora(sql_lora_files):
    reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files)
    tokenizer = TokenizerGroup(
        tokenizer_id="gpt2",
        enable_lora=True,
        max_num_seqs=1,
        max_input_length=None,
    )
    lora_request = LoRARequest("1", 1, sql_lora_files)
    assert reference_tokenizer.encode("prompt") == tokenizer.encode(
        request_id="request_id", prompt="prompt", lora_request=lora_request)
    assert reference_tokenizer.encode(
        "prompt") == await tokenizer.encode_async(request_id="request_id",
                                                  prompt="prompt",
                                                  lora_request=lora_request)
    assert isinstance(tokenizer.get_lora_tokenizer(None),
                      PreTrainedTokenizerBase)
    assert tokenizer.get_lora_tokenizer(
        None) == await tokenizer.get_lora_tokenizer_async(None)

    assert isinstance(tokenizer.get_lora_tokenizer(lora_request),
                      PreTrainedTokenizerBase)
    assert tokenizer.get_lora_tokenizer(
        lora_request) != tokenizer.get_lora_tokenizer(None)
    assert tokenizer.get_lora_tokenizer(
        lora_request) == await tokenizer.get_lora_tokenizer_async(lora_request)


def test_get_lora_tokenizer(sql_lora_files, tmpdir):
    lora_request = None
    tokenizer = get_lora_tokenizer(lora_request)
    assert not tokenizer

    lora_request = LoRARequest("1", 1, sql_lora_files)
    tokenizer = get_lora_tokenizer(lora_request)
    assert tokenizer.get_added_vocab()

    lora_request = LoRARequest("1", 1, str(tmpdir))
    tokenizer = get_lora_tokenizer(lora_request)
    assert not tokenizer
tests/lora/test_utils.py (new file, 172 lines)
@@ -0,0 +1,172 @@
from collections import OrderedDict

from torch import nn

from vllm.utils import LRUCache
from vllm.lora.utils import (parse_fine_tuned_lora_name, replace_submodule)


def test_parse_fine_tuned_lora_name():
    fixture = {
        ("base_model.model.lm_head.lora_A.weight", "lm_head", True),
        ("base_model.model.lm_head.lora_B.weight", "lm_head", False),
        (
            "base_model.model.model.embed_tokens.lora_embedding_A",
            "model.embed_tokens",
            True,
        ),
        (
            "base_model.model.model.embed_tokens.lora_embedding_B",
            "model.embed_tokens",
            False,
        ),
        (
            "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
            "model.layers.9.mlp.down_proj",
            True,
        ),
        (
            "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
            "model.layers.9.mlp.down_proj",
            False,
        ),
    }
    for name, module_name, is_lora_a in fixture:
        assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name)


def test_replace_submodule():
    model = nn.Sequential(
        OrderedDict([
            ("dense1", nn.Linear(764, 100)),
            ("act1", nn.ReLU()),
            ("dense2", nn.Linear(100, 50)),
            (
                "seq1",
                nn.Sequential(
                    OrderedDict([
                        ("dense1", nn.Linear(100, 10)),
                        ("dense2", nn.Linear(10, 50)),
                    ])),
            ),
            ("act2", nn.ReLU()),
            ("output", nn.Linear(50, 10)),
            ("outact", nn.Sigmoid()),
        ]))

    sigmoid = nn.Sigmoid()

    replace_submodule(model, "act1", sigmoid)
    assert dict(model.named_modules())["act1"] == sigmoid

    dense2 = nn.Linear(1, 5)
    replace_submodule(model, "seq1.dense2", dense2)
    assert dict(model.named_modules())["seq1.dense2"] == dense2


class TestLRUCache(LRUCache):

    def _on_remove(self, key, value):
        if not hasattr(self, "_remove_counter"):
            self._remove_counter = 0
        self._remove_counter += 1


def test_lru_cache():
    cache = TestLRUCache(3)

    cache.put(1, 1)
    assert len(cache) == 1

    cache.put(1, 1)
    assert len(cache) == 1

    cache.put(2, 2)
    assert len(cache) == 2

    cache.put(3, 3)
    assert len(cache) == 3
    assert set(cache.cache) == {1, 2, 3}

    cache.put(4, 4)
    assert len(cache) == 3
    assert set(cache.cache) == {2, 3, 4}
    assert cache._remove_counter == 1
    assert cache.get(2) == 2

    cache.put(5, 5)
    assert set(cache.cache) == {2, 4, 5}
    assert cache._remove_counter == 2

    assert cache.pop(5) == 5
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache.pop(10)
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache.get(10)
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache.put(6, 6)
    assert len(cache) == 3
    assert set(cache.cache) == {2, 4, 6}
    assert 2 in cache
    assert 4 in cache
    assert 6 in cache

    cache.remove_oldest()
    assert len(cache) == 2
    assert set(cache.cache) == {2, 6}
    assert cache._remove_counter == 4

    cache.clear()
    assert len(cache) == 0
    assert cache._remove_counter == 6

    cache._remove_counter = 0

    cache[1] = 1
    assert len(cache) == 1

    cache[1] = 1
    assert len(cache) == 1

    cache[2] = 2
    assert len(cache) == 2

    cache[3] = 3
    assert len(cache) == 3
    assert set(cache.cache) == {1, 2, 3}

    cache[4] = 4
    assert len(cache) == 3
    assert set(cache.cache) == {2, 3, 4}
    assert cache._remove_counter == 1
    assert cache[2] == 2

    cache[5] = 5
    assert set(cache.cache) == {2, 4, 5}
    assert cache._remove_counter == 2

    del cache[5]
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache.pop(10)
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache[6] = 6
    assert len(cache) == 3
    assert set(cache.cache) == {2, 4, 6}
    assert 2 in cache
    assert 4 in cache
    assert 6 in cache
tests/lora/test_worker.py (new file, 61 lines)
@@ -0,0 +1,61 @@
import os
import random
import tempfile
from unittest.mock import patch

from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig
from vllm.worker.worker import Worker


@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):
    worker = Worker(
        model_config=ModelConfig(
            "meta-llama/Llama-2-7b-hf",
            "meta-llama/Llama-2-7b-hf",
            tokenizer_mode="auto",
            trust_remote_code=False,
            download_dir=None,
            load_format="dummy",
            seed=0,
            dtype="float16",
            revision=None,
        ),
        parallel_config=ParallelConfig(1, 1, False),
        scheduler_config=SchedulerConfig(32, 32, 32, 256),
        local_rank=0,
        rank=0,
        lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
                               max_loras=32),
        distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
    )
    worker.init_model()
    worker.load_model()

    worker.model_runner.set_active_loras([], LoRAMapping([], []))
    assert worker.list_loras() == set()

    n_loras = 32
    lora_requests = [
        LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras)
    ]

    worker.model_runner.set_active_loras(lora_requests, LoRAMapping([], []))
    assert worker.list_loras() == {
        lora_request.lora_int_id
        for lora_request in lora_requests
    }

    for i in range(32):
        random.seed(i)
        iter_lora_requests = random.choices(lora_requests,
                                            k=random.randint(1, n_loras))
        random.shuffle(iter_lora_requests)
        iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)]
        worker.model_runner.set_active_loras(iter_lora_requests,
                                             LoRAMapping([], []))
        assert worker.list_loras().issuperset(
            {lora_request.lora_int_id
             for lora_request in iter_lora_requests})
tests/lora/utils.py (new file, 88 lines)
@@ -0,0 +1,88 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
||||||
|
|
||||||
|
|
||||||
|
class DummyLoRAManager:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self._loras = {}
|
||||||
|
|
||||||
|
def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
|
||||||
|
self._loras[module_name] = lora
|
||||||
|
|
||||||
|
def get_module_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
|
||||||
|
return self._loras.get(module_name, None)
|
||||||
|
|
||||||
|
def init_random_lora(self,
|
||||||
|
module_name: str,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
rank: int = 8,
|
||||||
|
generate_embeddings_tensor: int = 0):
|
||||||
|
lora = LoRALayerWeights(
|
||||||
|
module_name,
|
||||||
|
rank=rank,
|
||||||
|
lora_alpha=1,
|
||||||
|
lora_a=torch.rand([weight.shape[1], rank],
|
||||||
|
dtype=weight.dtype,
|
||||||
|
device="cuda"),
|
||||||
|
lora_b=torch.rand([rank, weight.shape[0]],
|
||||||
|
dtype=weight.dtype,
|
||||||
|
device="cuda"),
|
||||||
|
)
|
||||||
|
if generate_embeddings_tensor:
|
||||||
|
lora.embeddings_tensor = torch.rand(5,
|
||||||
|
generate_embeddings_tensor,
|
||||||
|
dtype=weight.dtype,
|
||||||
|
device="cuda")
|
||||||
|
self.set_module_lora(module_name, lora)
|
||||||
|
|
||||||
|
return lora
|
||||||
|
|
||||||
|
def init_lora(self,
|
||||||
|
module_name: str,
|
||||||
|
input_dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
rank=8,
|
||||||
|
noop=False,
|
||||||
|
embeddings_tensor=None):
|
||||||
|
lora = LoRALayerWeights(
|
||||||
|
module_name,
|
||||||
|
rank=rank,
|
||||||
|
lora_alpha=1,
|
||||||
|
lora_a=torch.rand([input_dim, rank], device="cuda"),
|
||||||
|
lora_b=torch.rand([rank, output_dim], device="cuda"),
|
||||||
|
embeddings_tensor=embeddings_tensor,
|
||||||
|
)
|
||||||
|
self.set_module_lora(module_name, lora)
|
||||||
|
return lora
|
||||||
|
|
||||||
|
def reset_lora(self):
|
||||||
|
self._loras = {}
|
||||||
|
|
||||||
|
def init_packed_lora(
|
||||||
|
self,
|
||||||
|
module_name: str,
|
||||||
|
input_dim: int,
|
||||||
|
output_dims: List[int],
|
||||||
|
noop_lora_index: List[int] = None,
|
||||||
|
rank=8,
|
||||||
|
):
|
||||||
|
base_loras = []
|
||||||
|
noop_lora_index = set(noop_lora_index or [])
|
||||||
|
|
||||||
|
for i, out_dim in enumerate(output_dims):
|
||||||
|
base_lora = self.init_lora(
|
||||||
|
module_name + "_000_" + str(i),
|
||||||
|
input_dim,
|
||||||
|
out_dim,
|
||||||
|
rank=rank,
|
||||||
|
noop=i in noop_lora_index,
|
||||||
|
)
|
||||||
|
base_loras.append(base_lora)
|
||||||
|
packed_lora = PackedLoRALayerWeights.pack(base_loras)
|
||||||
|
self.set_module_lora(module_name, packed_lora)
|
||||||
|
return packed_lora
|
||||||
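A quick usage sketch of the helper above; the module name and dimensions are made-up values, but the calls are the ones defined on `DummyLoRAManager` (note the random weights are allocated on CUDA):

manager = DummyLoRAManager()

# Pack one LoRA per output partition of a fused projection, then register
# the packed result under the fused module's name.
packed = manager.init_packed_lora(
    module_name="gate_up_proj",  # illustrative module name
    input_dim=1024,
    output_dims=[2048, 2048],
    rank=8,
)
assert manager.get_module_lora("gate_up_proj") is packed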
@@ -5,18 +5,11 @@ Run `pytest tests/models/test_models.py --forked`.
 import pytest
 
 MODELS = [
-    "facebook/opt-125m",
-    "meta-llama/Llama-2-7b-hf",
-    "mistralai/Mistral-7B-v0.1",
-    "Deci/DeciLM-7b",
-    "tiiuae/falcon-7b",
-    "gpt2",
-    "bigcode/tiny_starcoder_py",
-    "EleutherAI/gpt-j-6b",
-    "EleutherAI/pythia-70m",
-    "bigscience/bloom-560m",
-    "mosaicml/mpt-7b",
-    "microsoft/phi-2",
+    "facebook/opt-125m", "meta-llama/Llama-2-7b-hf",
+    "mistralai/Mistral-7B-v0.1", "Deci/DeciLM-7b", "tiiuae/falcon-7b", "gpt2",
+    "bigcode/tiny_starcoder_py", "EleutherAI/gpt-j-6b",
+    "EleutherAI/pythia-70m", "bigscience/bloom-560m", "mosaicml/mpt-7b",
+    "microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"
 ]
 
 
41
tests/prefix_caching/test_prefix_caching.py
Normal file
@@ -0,0 +1,41 @@
"""Compare the with and without prefix caching.

Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
import pytest

from vllm import LLM, SamplingParams

prefix = (
    "You are an expert school principal, skilled in effectively managing "
    "faculty and staff. Draft 10-15 questions for a potential first grade "
    "Head Teacher for my K-12, all-girls', independent school that emphasizes "
    "community, joyful discovery, and life-long learning. The candidate is "
    "coming in for a first-round panel interview for a 8th grade Math "
    "teaching role. They have 5 years of previous teaching experience "
    "as an assistant teacher at a co-ed, public school with experience "
    "in middle school math teaching. Based on these information, fulfill "
    "the following paragraph: ")


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_tokens", [16])
def test_prefix_caching(
    example_prompts,
    model: str,
    max_tokens: int,
):
    llm = LLM(model=model)
    # -1 since the last token can change when concatenating prompts.
    prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1
    prompts = [prefix + prompt for prompt in example_prompts]
    sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
    outputs_without_prefix = llm.generate(prompts, sampling_params)
    outputs_with_prefix = llm.generate(prompts,
                                       sampling_params,
                                       prefix_pos=[prefix_pos] * len(prompts))
    for output_without_prefix, output_with_prefix in zip(
            outputs_without_prefix, outputs_with_prefix):
        assert (output_without_prefix.outputs[0].token_ids ==
                output_with_prefix.outputs[0].token_ids)
    assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 1
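Outside the test harness, the same prefix-reuse pattern looks roughly like this; the model and prompts are placeholders, while `prefix_pos` follows the API exercised in the test above:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
prefix = "Shared system prompt reused by every request. "
prompts = [prefix + "Question one?", prefix + "Question two?"]

# -1 because the last prefix token can change once a prompt is appended.
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1
outputs = llm.generate(prompts,
                       SamplingParams(temperature=0.0, max_tokens=16),
                       prefix_pos=[prefix_pos] * len(prompts))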
@@ -30,6 +30,7 @@ def test_get_prompt_logprobs(
                                      temperature=0.0)
     vllm_results = vllm_model.model.generate(
         example_prompts, sampling_params=vllm_sampling_params)
+    del vllm_model
 
     # Test whether logprobs are included in the results.
     for result in vllm_results:
Some files were not shown because too many files have changed in this diff.