Compare commits
161 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e2fb71ec9f | ||
|
|
f936657eb6 | ||
|
|
6f88f762bf | ||
|
|
202351d5bf | ||
|
|
2e8e49fce3 | ||
|
|
a8e98aee0c | ||
|
|
bb1ba58f06 | ||
|
|
7bedab5748 | ||
|
|
20f7cc4cde | ||
|
|
649aa730c5 | ||
|
|
a19bc5c628 | ||
|
|
28e616c4e3 | ||
|
|
30e775281d | ||
|
|
21877b0d75 | ||
|
|
cf5cb1e33e | ||
|
|
03ffd0a022 | ||
|
|
a425bd9a9a | ||
|
|
bbbf86565f | ||
|
|
9f6be8692e | ||
|
|
f187877945 | ||
|
|
947b794146 | ||
|
|
8d926e91f1 | ||
|
|
4ee52bb169 | ||
|
|
7d7e3b78a3 | ||
|
|
f98b745a81 | ||
|
|
2d1e86f1b1 | ||
|
|
1ac4ccf73c | ||
|
|
2ac4d5e2bf | ||
|
|
3302f0aef3 | ||
|
|
6f2dd6c37e | ||
|
|
bc0644574c | ||
|
|
400b8289f7 | ||
|
|
c1026311b5 | ||
|
|
2b1c116b5a | ||
|
|
cc796b1358 | ||
|
|
f029ef94d7 | ||
|
|
95592fa00a | ||
|
|
fbe66e1d0b | ||
|
|
90979c38f8 | ||
|
|
e21d7687a9 | ||
|
|
ff36139ffc | ||
|
|
e3e79e9e8a | ||
|
|
b9fe4616f9 | ||
|
|
64ca424e75 | ||
|
|
b5f93d0631 | ||
|
|
a58936966f | ||
|
|
dd54a4b026 | ||
|
|
eda1a7cad3 | ||
|
|
f04908cae7 | ||
|
|
ab019eea75 | ||
|
|
9841d48a10 | ||
|
|
3272d7a0b7 | ||
|
|
0bb1e885a0 | ||
|
|
d6545ad22e | ||
|
|
90eb3f43ca | ||
|
|
e67b4f2c2a | ||
|
|
d6770d1f23 | ||
|
|
b9cecc2635 | ||
|
|
898285c9bf | ||
|
|
a62de9ecfd | ||
|
|
4042d192f5 | ||
|
|
1117aa1411 | ||
|
|
080438477f | ||
|
|
4b5bcf8906 | ||
|
|
852ef5b4f5 | ||
|
|
db09d4ad83 | ||
|
|
c957c741d9 | ||
|
|
c07ece5ca4 | ||
|
|
7a9c20c715 | ||
|
|
005ba458b5 | ||
|
|
320a622ec4 | ||
|
|
c9927c1a6a | ||
|
|
fbd80ad409 | ||
|
|
22379d5513 | ||
|
|
1696725879 | ||
|
|
002800f081 | ||
|
|
e15932bb60 | ||
|
|
ce741ba3e4 | ||
|
|
bf87484efa | ||
|
|
8ce9c50d40 | ||
|
|
32b6816e55 | ||
|
|
c128d69856 | ||
|
|
55b28b1eee | ||
|
|
e11222333f | ||
|
|
28873a2799 | ||
|
|
0080d8329d | ||
|
|
0d93f15694 | ||
|
|
becd7a56f1 | ||
|
|
75471386de | ||
|
|
d2b2eed67c | ||
|
|
4b6f069b6f | ||
|
|
791d79de32 | ||
|
|
94d2f59895 | ||
|
|
75c0ca9d43 | ||
|
|
2a4ec90854 | ||
|
|
85ebcda94d | ||
|
|
d64bf1646c | ||
|
|
a41c20435e | ||
|
|
eedac9dba0 | ||
|
|
14f9c72bfd | ||
|
|
ad5f2fe34c | ||
|
|
4f8584756d | ||
|
|
65fc1c3127 | ||
|
|
c393af6cd7 | ||
|
|
0c04ce3234 | ||
|
|
73b3de79ea | ||
|
|
d1744376ae | ||
|
|
805de738f6 | ||
|
|
1b151ed181 | ||
|
|
e06f504a76 | ||
|
|
462ae5220a | ||
|
|
66c54aa9c3 | ||
|
|
735ecfff61 | ||
|
|
a57d13cc96 | ||
|
|
79af7e96a0 | ||
|
|
621980bdc0 | ||
|
|
aa84c92ef6 | ||
|
|
f7389f4763 | ||
|
|
55fe8a81ec | ||
|
|
e8ddc08ec8 | ||
|
|
1b0bd0fe8a | ||
|
|
20044cab7a | ||
|
|
64f23c2900 | ||
|
|
d4c7755ca8 | ||
|
|
aa39e42c5a | ||
|
|
953f28cf9a | ||
|
|
c0d00f5be6 | ||
|
|
58a072be15 | ||
|
|
82ad323dee | ||
|
|
df5dd3c68e | ||
|
|
2d867b55fa | ||
|
|
d7a1c6d614 | ||
|
|
7d5a155e4a | ||
|
|
1dde34e0f8 | ||
|
|
6fc2a38b11 | ||
|
|
c487a221ee | ||
|
|
9925c17940 | ||
|
|
8c4b2592fb | ||
|
|
cf21a9bd5c | ||
|
|
16c3e295a8 | ||
|
|
bda41c70dd | ||
|
|
453bafb96f | ||
|
|
328d231c17 | ||
|
|
b4b195b360 | ||
|
|
20b0d88d16 | ||
|
|
2bdea7ac11 | ||
|
|
58df2883cb | ||
|
|
6d7d95a70a | ||
|
|
96853af5a8 | ||
|
|
dbed69058c | ||
|
|
7b6ae94059 | ||
|
|
c6dfc3cdbe | ||
|
|
51be365143 | ||
|
|
c894836108 | ||
|
|
75beba29b5 | ||
|
|
ddfdf470ae | ||
|
|
b6fbb9a565 | ||
|
|
2179e4f4c5 | ||
|
|
a945fcc2ae | ||
|
|
be54f8e5c4 | ||
|
|
b396cb4998 |
101
.github/workflows/publish.yml
vendored
Normal file
101
.github/workflows/publish.yml
vendored
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
# This workflow will upload a Python Package to Release asset
|
||||||
|
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions
|
||||||
|
|
||||||
|
name: Create Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- v*
|
||||||
|
|
||||||
|
# Needed to create release and upload assets
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
release:
|
||||||
|
# Retrieve tag and create release
|
||||||
|
name: Create Release
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
outputs:
|
||||||
|
upload_url: ${{ steps.create_release.outputs.upload_url }}
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Extract branch info
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Create Release
|
||||||
|
id: create_release
|
||||||
|
uses: "actions/github-script@v6"
|
||||||
|
env:
|
||||||
|
RELEASE_TAG: ${{ env.release_tag }}
|
||||||
|
with:
|
||||||
|
github-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||||
|
script: |
|
||||||
|
const script = require('.github/workflows/scripts/create_release.js')
|
||||||
|
await script(github, context, core)
|
||||||
|
|
||||||
|
wheel:
|
||||||
|
name: Build Wheel
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
needs: release
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: ['ubuntu-20.04']
|
||||||
|
python-version: ['3.8', '3.9', '3.10', '3.11']
|
||||||
|
cuda-version: ['11.8'] # Github runner can't build anything older than 11.8
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Set up Linux Env
|
||||||
|
if: ${{ runner.os == 'Linux' }}
|
||||||
|
run: |
|
||||||
|
bash -x .github/workflows/scripts/env.sh
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Install CUDA ${{ matrix.cuda-version }}
|
||||||
|
run: |
|
||||||
|
bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}
|
||||||
|
|
||||||
|
- name: Install PyTorch-cu${{ matrix.cuda-version }}
|
||||||
|
run: |
|
||||||
|
bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
|
||||||
|
|
||||||
|
- name: Build wheel
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
|
||||||
|
wheel_name=$(ls dist/*whl | xargs -n 1 basename)
|
||||||
|
asset_name=${wheel_name//"linux"/"manylinux1"}
|
||||||
|
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
|
||||||
|
echo "asset_name=${asset_name}" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Upload Release Asset
|
||||||
|
uses: actions/upload-release-asset@v1
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
with:
|
||||||
|
upload_url: ${{ needs.release.outputs.upload_url }}
|
||||||
|
asset_path: ./dist/${{ env.wheel_name }}
|
||||||
|
asset_name: ${{ env.asset_name }}
|
||||||
|
asset_content_type: application/*
|
||||||
|
|
||||||
|
# (Danielkinz): This last step will publish the .whl to pypi. Warning: untested
|
||||||
|
# - name: Publish package
|
||||||
|
# uses: pypa/gh-action-pypi-publish@release/v1.8
|
||||||
|
# with:
|
||||||
|
# repository-url: https://test.pypi.org/legacy/
|
||||||
|
# password: ${{ secrets.PYPI_API_TOKEN }}
|
||||||
|
# skip-existing: true
|
||||||
15
.github/workflows/scripts/build.sh
vendored
Normal file
15
.github/workflows/scripts/build.sh
vendored
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
python_executable=python$1
|
||||||
|
cuda_home=/usr/local/cuda-$2
|
||||||
|
|
||||||
|
# Update paths
|
||||||
|
PATH=${cuda_home}/bin:$PATH
|
||||||
|
LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
|
||||||
|
|
||||||
|
# Install requirements
|
||||||
|
$python_executable -m pip install wheel packaging
|
||||||
|
$python_executable -m pip install -r requirements.txt
|
||||||
|
|
||||||
|
# Build
|
||||||
|
$python_executable setup.py bdist_wheel --dist-dir=dist
|
||||||
20
.github/workflows/scripts/create_release.js
vendored
Normal file
20
.github/workflows/scripts/create_release.js
vendored
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
// Uses Github's API to create the release and wait for result.
|
||||||
|
// We use a JS script since github CLI doesn't provide a way to wait for the release's creation and returns immediately.
|
||||||
|
|
||||||
|
module.exports = async (github, context, core) => {
|
||||||
|
try {
|
||||||
|
const response = await github.rest.repos.createRelease({
|
||||||
|
draft: false,
|
||||||
|
generate_release_notes: true,
|
||||||
|
name: process.env.RELEASE_TAG,
|
||||||
|
owner: context.repo.owner,
|
||||||
|
prerelease: false,
|
||||||
|
repo: context.repo.repo,
|
||||||
|
tag_name: process.env.RELEASE_TAG,
|
||||||
|
});
|
||||||
|
|
||||||
|
core.setOutput('upload_url', response.data.upload_url);
|
||||||
|
} catch (error) {
|
||||||
|
core.setFailed(error.message);
|
||||||
|
}
|
||||||
|
}
|
||||||
18
.github/workflows/scripts/cuda-install.sh
vendored
Normal file
18
.github/workflows/scripts/cuda-install.sh
vendored
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Replace '.' with '-' ex: 11.8 -> 11-8
|
||||||
|
cuda_version=$(echo $1 | tr "." "-")
|
||||||
|
# Removes '-' and '.' ex: ubuntu-20.04 -> ubuntu2004
|
||||||
|
OS=$(echo $2 | tr -d ".\-")
|
||||||
|
|
||||||
|
# Installs CUDA
|
||||||
|
wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb
|
||||||
|
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
|
rm cuda-keyring_1.1-1_all.deb
|
||||||
|
sudo apt -qq update
|
||||||
|
sudo apt -y install cuda-${cuda_version} cuda-nvcc-${cuda_version} cuda-libraries-dev-${cuda_version}
|
||||||
|
sudo apt clean
|
||||||
|
|
||||||
|
# Test nvcc
|
||||||
|
PATH=/usr/local/cuda-$1/bin:${PATH}
|
||||||
|
nvcc --version
|
||||||
56
.github/workflows/scripts/env.sh
vendored
Normal file
56
.github/workflows/scripts/env.sh
vendored
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# This file installs common linux environment tools
|
||||||
|
|
||||||
|
export LANG C.UTF-8
|
||||||
|
|
||||||
|
# python_version=$1
|
||||||
|
|
||||||
|
sudo apt-get update && \
|
||||||
|
sudo apt-get install -y --no-install-recommends \
|
||||||
|
software-properties-common \
|
||||||
|
|
||||||
|
sudo apt-get install -y --no-install-recommends \
|
||||||
|
build-essential \
|
||||||
|
apt-utils \
|
||||||
|
ca-certificates \
|
||||||
|
wget \
|
||||||
|
git \
|
||||||
|
vim \
|
||||||
|
libssl-dev \
|
||||||
|
curl \
|
||||||
|
unzip \
|
||||||
|
unrar \
|
||||||
|
cmake \
|
||||||
|
net-tools \
|
||||||
|
sudo \
|
||||||
|
autotools-dev \
|
||||||
|
rsync \
|
||||||
|
jq \
|
||||||
|
openssh-server \
|
||||||
|
tmux \
|
||||||
|
screen \
|
||||||
|
htop \
|
||||||
|
pdsh \
|
||||||
|
openssh-client \
|
||||||
|
lshw \
|
||||||
|
dmidecode \
|
||||||
|
util-linux \
|
||||||
|
automake \
|
||||||
|
autoconf \
|
||||||
|
libtool \
|
||||||
|
net-tools \
|
||||||
|
pciutils \
|
||||||
|
libpci-dev \
|
||||||
|
libaio-dev \
|
||||||
|
libcap2 \
|
||||||
|
libtinfo5 \
|
||||||
|
fakeroot \
|
||||||
|
devscripts \
|
||||||
|
debhelper \
|
||||||
|
nfs-common
|
||||||
|
|
||||||
|
# Remove github bloat files to free up disk space
|
||||||
|
sudo rm -rf "/usr/local/share/boost"
|
||||||
|
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
|
||||||
|
sudo rm -rf "/usr/share/dotnet"
|
||||||
14
.github/workflows/scripts/pytorch-install.sh
vendored
Normal file
14
.github/workflows/scripts/pytorch-install.sh
vendored
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
python_executable=python$1
|
||||||
|
cuda_version=$2
|
||||||
|
|
||||||
|
# Install torch
|
||||||
|
$python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya
|
||||||
|
$python_executable -m pip install torch -f https://download.pytorch.org/whl/cu${cuda_version//./}/torch_stable.html
|
||||||
|
|
||||||
|
# Print version information
|
||||||
|
$python_executable --version
|
||||||
|
$python_executable -c "import torch; print('PyTorch:', torch.__version__)"
|
||||||
|
$python_executable -c "import torch; print('CUDA:', torch.version.cuda)"
|
||||||
|
$python_executable -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
|
||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -173,3 +173,7 @@ cython_debug/
|
|||||||
|
|
||||||
# Sphinx documentation
|
# Sphinx documentation
|
||||||
_build/
|
_build/
|
||||||
|
|
||||||
|
# vim swap files
|
||||||
|
*.swo
|
||||||
|
*.swp
|
||||||
|
|||||||
42
README.md
42
README.md
@@ -10,13 +10,26 @@ Easy, fast, and cheap LLM serving for everyone
|
|||||||
</h3>
|
</h3>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
| <a href="https://vllm.readthedocs.io/en/latest/"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://github.com/vllm-project/vllm/discussions"><b>Discussions</b></a> |
|
| <a href="https://vllm.readthedocs.io/en/latest/"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |
|
||||||
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
**The First vLLM Bay Area Meetup (Oct 5th 6pm-8pm PT)**
|
||||||
|
|
||||||
|
We are excited to invite you to the first vLLM meetup!
|
||||||
|
The vLLM team will share recent updates and roadmap.
|
||||||
|
We will also have vLLM users and contributors coming up to the stage to share their experiences.
|
||||||
|
Please register [here](https://lu.ma/first-vllm-meetup) and join us!
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
*Latest News* 🔥
|
*Latest News* 🔥
|
||||||
|
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
|
||||||
|
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
|
||||||
|
- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
|
||||||
|
- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
|
||||||
- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
|
- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
|
||||||
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
|
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
|
||||||
|
|
||||||
@@ -33,21 +46,28 @@ vLLM is fast with:
|
|||||||
|
|
||||||
vLLM is flexible and easy to use with:
|
vLLM is flexible and easy to use with:
|
||||||
|
|
||||||
- Seamless integration with popular HuggingFace models
|
- Seamless integration with popular Hugging Face models
|
||||||
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
|
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
|
||||||
- Tensor parallelism support for distributed inference
|
- Tensor parallelism support for distributed inference
|
||||||
- Streaming outputs
|
- Streaming outputs
|
||||||
- OpenAI-compatible API server
|
- OpenAI-compatible API server
|
||||||
|
|
||||||
vLLM seamlessly supports many Huggingface models, including the following architectures:
|
vLLM seamlessly supports many Hugging Face models, including the following architectures:
|
||||||
|
|
||||||
|
- Aquila (`BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
|
||||||
|
- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
|
||||||
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
|
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
|
||||||
|
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
|
||||||
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
|
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
|
||||||
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
|
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
|
||||||
|
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
|
||||||
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
|
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
|
||||||
- LLaMA (`lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
|
||||||
|
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
||||||
|
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
|
||||||
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
||||||
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
||||||
|
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
|
||||||
|
|
||||||
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
|
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
|
||||||
|
|
||||||
@@ -64,7 +84,7 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started
|
|||||||
|
|
||||||
## Performance
|
## Performance
|
||||||
|
|
||||||
vLLM outperforms HuggingFace Transformers (HF) by up to 24x and Text Generation Inference (TGI) by up to 3.5x, in terms of throughput.
|
vLLM outperforms Hugging Face Transformers (HF) by up to 24x and Text Generation Inference (TGI) by up to 3.5x, in terms of throughput.
|
||||||
For details, check out our [blog post](https://vllm.ai).
|
For details, check out our [blog post](https://vllm.ai).
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
@@ -96,3 +116,15 @@ For details, check out our [blog post](https://vllm.ai).
|
|||||||
|
|
||||||
We welcome and value any contributions and collaborations.
|
We welcome and value any contributions and collaborations.
|
||||||
Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
|
Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
|
||||||
|
```bibtex
|
||||||
|
@inproceedings{kwon2023efficient,
|
||||||
|
title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
|
||||||
|
author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
|
||||||
|
booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
|
||||||
|
year={2023}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|||||||
@@ -18,9 +18,11 @@ def main(args: argparse.Namespace):
|
|||||||
llm = LLM(
|
llm = LLM(
|
||||||
model=args.model,
|
model=args.model,
|
||||||
tokenizer=args.tokenizer,
|
tokenizer=args.tokenizer,
|
||||||
|
quantization=args.quantization,
|
||||||
tensor_parallel_size=args.tensor_parallel_size,
|
tensor_parallel_size=args.tensor_parallel_size,
|
||||||
max_num_seqs=args.batch_size,
|
max_num_seqs=args.batch_size,
|
||||||
max_num_batched_tokens=args.batch_size * args.input_len,
|
max_num_batched_tokens=args.batch_size * args.input_len,
|
||||||
|
trust_remote_code=args.trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
@@ -62,17 +64,28 @@ def main(args: argparse.Namespace):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description='Benchmark the latency of processing a single batch of '
|
description='Benchmark the latency of processing a single batch of '
|
||||||
'requests till completion.')
|
'requests till completion.')
|
||||||
parser.add_argument('--model', type=str, default='facebook/opt-125m')
|
parser.add_argument('--model', type=str, default='facebook/opt-125m')
|
||||||
parser.add_argument('--tokenizer', type=str, default=None)
|
parser.add_argument('--tokenizer', type=str, default=None)
|
||||||
|
parser.add_argument('--quantization',
|
||||||
|
'-q',
|
||||||
|
choices=['awq', None],
|
||||||
|
default=None)
|
||||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
||||||
parser.add_argument('--input-len', type=int, default=32)
|
parser.add_argument('--input-len', type=int, default=32)
|
||||||
parser.add_argument('--output-len', type=int, default=128)
|
parser.add_argument('--output-len', type=int, default=128)
|
||||||
parser.add_argument('--batch-size', type=int, default=8)
|
parser.add_argument('--batch-size', type=int, default=8)
|
||||||
parser.add_argument('--n', type=int, default=1,
|
parser.add_argument('--n',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
help='Number of generated sequences per prompt.')
|
help='Number of generated sequences per prompt.')
|
||||||
parser.add_argument('--use-beam-search', action='store_true')
|
parser.add_argument('--use-beam-search', action='store_true')
|
||||||
parser.add_argument('--num-iters', type=int, default=3,
|
parser.add_argument('--num-iters',
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
help='Number of iterations to run.')
|
help='Number of iterations to run.')
|
||||||
|
parser.add_argument('--trust-remote-code',
|
||||||
|
action='store_true',
|
||||||
|
help='trust remote code from huggingface')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
@@ -177,7 +177,7 @@ def main(args: argparse.Namespace):
|
|||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
|
|
||||||
api_url = f"http://{args.host}:{args.port}/generate"
|
api_url = f"http://{args.host}:{args.port}/generate"
|
||||||
tokenizer = get_tokenizer(args.tokenizer)
|
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||||
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||||
|
|
||||||
benchmark_start_time = time.time()
|
benchmark_start_time = time.time()
|
||||||
@@ -227,5 +227,7 @@ if __name__ == "__main__":
|
|||||||
"Otherwise, we use Poisson process to synthesize "
|
"Otherwise, we use Poisson process to synthesize "
|
||||||
"the request arrival times.")
|
"the request arrival times.")
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
|
parser.add_argument('--trust-remote-code', action='store_true',
|
||||||
|
help='trust remote code from huggingface')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import List, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
|
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
|
||||||
@@ -22,15 +22,10 @@ def sample_requests(
|
|||||||
with open(dataset_path) as f:
|
with open(dataset_path) as f:
|
||||||
dataset = json.load(f)
|
dataset = json.load(f)
|
||||||
# Filter out the conversations with less than 2 turns.
|
# Filter out the conversations with less than 2 turns.
|
||||||
dataset = [
|
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||||
data for data in dataset
|
|
||||||
if len(data["conversations"]) >= 2
|
|
||||||
]
|
|
||||||
# Only keep the first two turns of each conversation.
|
# Only keep the first two turns of each conversation.
|
||||||
dataset = [
|
dataset = [(data["conversations"][0]["value"],
|
||||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
data["conversations"][1]["value"]) for data in dataset]
|
||||||
for data in dataset
|
|
||||||
]
|
|
||||||
|
|
||||||
# Tokenize the prompts and completions.
|
# Tokenize the prompts and completions.
|
||||||
prompts = [prompt for prompt, _ in dataset]
|
prompts = [prompt for prompt, _ in dataset]
|
||||||
@@ -63,16 +58,20 @@ def run_vllm(
|
|||||||
requests: List[Tuple[str, int, int]],
|
requests: List[Tuple[str, int, int]],
|
||||||
model: str,
|
model: str,
|
||||||
tokenizer: str,
|
tokenizer: str,
|
||||||
|
quantization: Optional[str],
|
||||||
tensor_parallel_size: int,
|
tensor_parallel_size: int,
|
||||||
seed: int,
|
seed: int,
|
||||||
n: int,
|
n: int,
|
||||||
use_beam_search: bool,
|
use_beam_search: bool,
|
||||||
|
trust_remote_code: bool,
|
||||||
) -> float:
|
) -> float:
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
quantization=quantization,
|
||||||
tensor_parallel_size=tensor_parallel_size,
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add the requests to the engine.
|
# Add the requests to the engine.
|
||||||
@@ -106,9 +105,11 @@ def run_hf(
|
|||||||
n: int,
|
n: int,
|
||||||
use_beam_search: bool,
|
use_beam_search: bool,
|
||||||
max_batch_size: int,
|
max_batch_size: int,
|
||||||
|
trust_remote_code: bool,
|
||||||
) -> float:
|
) -> float:
|
||||||
assert not use_beam_search
|
assert not use_beam_search
|
||||||
llm = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
|
llm = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
||||||
if llm.config.model_type == "llama":
|
if llm.config.model_type == "llama":
|
||||||
# To enable padding in the HF backend.
|
# To enable padding in the HF backend.
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
@@ -128,13 +129,14 @@ def run_hf(
|
|||||||
if len(batch) < max_batch_size and i != len(requests) - 1:
|
if len(batch) < max_batch_size and i != len(requests) - 1:
|
||||||
# Check if we can add more requests to the batch.
|
# Check if we can add more requests to the batch.
|
||||||
_, next_prompt_len, next_output_len = requests[i + 1]
|
_, next_prompt_len, next_output_len = requests[i + 1]
|
||||||
if (max(max_prompt_len, next_prompt_len) + max(
|
if (max(max_prompt_len, next_prompt_len) +
|
||||||
max_output_len, next_output_len)) <= 2048:
|
max(max_output_len, next_output_len)) <= 2048:
|
||||||
# We can add more requests to the batch.
|
# We can add more requests to the batch.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Generate the sequences.
|
# Generate the sequences.
|
||||||
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
|
input_ids = tokenizer(batch, return_tensors="pt",
|
||||||
|
padding=True).input_ids
|
||||||
llm_outputs = llm.generate(
|
llm_outputs = llm.generate(
|
||||||
input_ids=input_ids.cuda(),
|
input_ids=input_ids.cuda(),
|
||||||
do_sample=not use_beam_search,
|
do_sample=not use_beam_search,
|
||||||
@@ -161,44 +163,62 @@ def main(args: argparse.Namespace):
|
|||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
|
|
||||||
# Sample the requests.
|
# Sample the requests.
|
||||||
tokenizer = get_tokenizer(args.tokenizer)
|
tokenizer = get_tokenizer(args.tokenizer,
|
||||||
|
trust_remote_code=args.trust_remote_code)
|
||||||
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||||
|
|
||||||
if args.backend == "vllm":
|
if args.backend == "vllm":
|
||||||
elapsed_time = run_vllm(
|
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
|
||||||
requests, args.model, args.tokenizer, args.tensor_parallel_size,
|
args.quantization, args.tensor_parallel_size,
|
||||||
args.seed, args.n, args.use_beam_search)
|
args.seed, args.n, args.use_beam_search,
|
||||||
|
args.trust_remote_code)
|
||||||
elif args.backend == "hf":
|
elif args.backend == "hf":
|
||||||
assert args.tensor_parallel_size == 1
|
assert args.tensor_parallel_size == 1
|
||||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||||
args.use_beam_search, args.hf_max_batch_size)
|
args.use_beam_search, args.hf_max_batch_size,
|
||||||
|
args.trust_remote_code)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown backend: {args.backend}")
|
raise ValueError(f"Unknown backend: {args.backend}")
|
||||||
total_num_tokens = sum(
|
total_num_tokens = sum(prompt_len + output_len
|
||||||
prompt_len + output_len
|
for _, prompt_len, output_len in requests)
|
||||||
for _, prompt_len, output_len in requests
|
|
||||||
)
|
|
||||||
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||||
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
|
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
||||||
parser.add_argument("--backend", type=str, choices=["vllm", "hf"],
|
parser.add_argument("--backend",
|
||||||
|
type=str,
|
||||||
|
choices=["vllm", "hf"],
|
||||||
default="vllm")
|
default="vllm")
|
||||||
parser.add_argument("--dataset", type=str, required=True,
|
parser.add_argument("--dataset",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
help="Path to the dataset.")
|
help="Path to the dataset.")
|
||||||
parser.add_argument("--model", type=str, default="facebook/opt-125m")
|
parser.add_argument("--model", type=str, default="facebook/opt-125m")
|
||||||
parser.add_argument("--tokenizer", type=str, default=None)
|
parser.add_argument("--tokenizer", type=str, default=None)
|
||||||
|
parser.add_argument('--quantization',
|
||||||
|
'-q',
|
||||||
|
choices=['awq', None],
|
||||||
|
default=None)
|
||||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
||||||
parser.add_argument("--n", type=int, default=1,
|
parser.add_argument("--n",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
help="Number of generated sequences per prompt.")
|
help="Number of generated sequences per prompt.")
|
||||||
parser.add_argument("--use-beam-search", action="store_true")
|
parser.add_argument("--use-beam-search", action="store_true")
|
||||||
parser.add_argument("--num-prompts", type=int, default=1000,
|
parser.add_argument("--num-prompts",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
help="Number of prompts to process.")
|
help="Number of prompts to process.")
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument("--hf-max-batch-size", type=int, default=None,
|
parser.add_argument("--hf-max-batch-size",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
help="Maximum batch size for HF backend.")
|
help="Maximum batch size for HF backend.")
|
||||||
|
parser.add_argument('--trust-remote-code',
|
||||||
|
action='store_true',
|
||||||
|
help='trust remote code from huggingface')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.backend == "vllm":
|
if args.backend == "vllm":
|
||||||
@@ -207,6 +227,8 @@ if __name__ == "__main__":
|
|||||||
elif args.backend == "hf":
|
elif args.backend == "hf":
|
||||||
if args.hf_max_batch_size is None:
|
if args.hf_max_batch_size is None:
|
||||||
raise ValueError("HF max batch size is required for HF backend.")
|
raise ValueError("HF max batch size is required for HF backend.")
|
||||||
|
if args.quantization is not None:
|
||||||
|
raise ValueError("Quantization is only for vLLM backend.")
|
||||||
if args.tokenizer is None:
|
if args.tokenizer is None:
|
||||||
args.tokenizer = args.model
|
args.tokenizer = args.model
|
||||||
|
|
||||||
|
|||||||
@@ -4,9 +4,25 @@ void silu_and_mul(
|
|||||||
torch::Tensor& out,
|
torch::Tensor& out,
|
||||||
torch::Tensor& input);
|
torch::Tensor& input);
|
||||||
|
|
||||||
|
void gelu_new(
|
||||||
|
torch::Tensor& out,
|
||||||
|
torch::Tensor& input);
|
||||||
|
|
||||||
|
void gelu_fast(
|
||||||
|
torch::Tensor& out,
|
||||||
|
torch::Tensor& input);
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
m.def(
|
m.def(
|
||||||
"silu_and_mul",
|
"silu_and_mul",
|
||||||
&silu_and_mul,
|
&silu_and_mul,
|
||||||
"Activation function used in SwiGLU.");
|
"Activation function used in SwiGLU.");
|
||||||
|
m.def(
|
||||||
|
"gelu_new",
|
||||||
|
&gelu_new,
|
||||||
|
"GELU implementation used in GPT-2.");
|
||||||
|
m.def(
|
||||||
|
"gelu_fast",
|
||||||
|
&gelu_fast,
|
||||||
|
"Approximate GELU implementation.");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
@@ -34,9 +36,7 @@ void silu_and_mul(
|
|||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(d, 1024));
|
dim3 block(std::min(d, 1024));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
at::ScalarType::Half,
|
|
||||||
at::ScalarType::BFloat16,
|
|
||||||
input.scalar_type(),
|
input.scalar_type(),
|
||||||
"silu_and_mul_kernel",
|
"silu_and_mul_kernel",
|
||||||
[&] {
|
[&] {
|
||||||
@@ -46,3 +46,69 @@ void silu_and_mul(
|
|||||||
d);
|
d);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
// Element-wise activation kernel template.
|
||||||
|
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
||||||
|
__global__ void activation_kernel(
|
||||||
|
scalar_t* __restrict__ out, // [num_tokens, d]
|
||||||
|
const scalar_t* __restrict__ input, // [num_tokens, d]
|
||||||
|
const int d) {
|
||||||
|
const int token_idx = blockIdx.x;
|
||||||
|
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||||
|
const scalar_t x = __ldg(&input[token_idx * d + idx]);
|
||||||
|
out[token_idx * d + idx] = ACT_FN(x);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
// Launch element-wise activation kernel.
|
||||||
|
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
|
||||||
|
int num_tokens = input.size(0); \
|
||||||
|
int d = input.size(1); \
|
||||||
|
dim3 grid(num_tokens); \
|
||||||
|
dim3 block(std::min(d, 1024)); \
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||||
|
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||||
|
input.scalar_type(), \
|
||||||
|
"activation_kernel", \
|
||||||
|
[&] { \
|
||||||
|
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
|
||||||
|
out.data_ptr<scalar_t>(), \
|
||||||
|
input.data_ptr<scalar_t>(), \
|
||||||
|
d); \
|
||||||
|
});
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
|
||||||
|
const float x3 = (float) (x * x * x);
|
||||||
|
const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
|
||||||
|
return ((T) 0.5) * x * (((T) 1.0) + t);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
|
||||||
|
const float f = (float) x;
|
||||||
|
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
|
||||||
|
return ((T) 0.5) * x * (((T) 1.0) + t);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
void gelu_new(
|
||||||
|
torch::Tensor& out, // [num_tokens, d]
|
||||||
|
torch::Tensor& input) // [num_tokens, d]
|
||||||
|
{
|
||||||
|
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
|
||||||
|
}
|
||||||
|
|
||||||
|
void gelu_fast(
|
||||||
|
torch::Tensor& out, // [num_tokens, d]
|
||||||
|
torch::Tensor& input) // [num_tokens, d]
|
||||||
|
{
|
||||||
|
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ void single_query_cached_kv_attention(
|
|||||||
torch::Tensor& query,
|
torch::Tensor& query,
|
||||||
torch::Tensor& key_cache,
|
torch::Tensor& key_cache,
|
||||||
torch::Tensor& value_cache,
|
torch::Tensor& value_cache,
|
||||||
|
torch::Tensor& head_mapping,
|
||||||
float scale,
|
float scale,
|
||||||
torch::Tensor& block_tables,
|
torch::Tensor& block_tables,
|
||||||
torch::Tensor& context_lens,
|
torch::Tensor& context_lens,
|
||||||
|
|||||||
@@ -74,15 +74,20 @@ template<
|
|||||||
__global__ void single_query_cached_kv_attention_kernel(
|
__global__ void single_query_cached_kv_attention_kernel(
|
||||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
|
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||||
|
const int* __restrict__ head_mapping, // [num_heads]
|
||||||
const float scale,
|
const float scale,
|
||||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
const int* __restrict__ context_lens, // [num_seqs]
|
const int* __restrict__ context_lens, // [num_seqs]
|
||||||
const int max_num_blocks_per_seq,
|
const int max_num_blocks_per_seq,
|
||||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||||
const int q_stride) {
|
const int q_stride,
|
||||||
|
const int kv_block_stride,
|
||||||
|
const int kv_head_stride) {
|
||||||
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||||
|
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
|
||||||
|
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
|
||||||
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
|
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
|
||||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||||
const int thread_idx = threadIdx.x;
|
const int thread_idx = threadIdx.x;
|
||||||
@@ -91,6 +96,7 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|||||||
|
|
||||||
const int head_idx = blockIdx.x;
|
const int head_idx = blockIdx.x;
|
||||||
const int num_heads = gridDim.x;
|
const int num_heads = gridDim.x;
|
||||||
|
const int kv_head_idx = head_mapping[head_idx];
|
||||||
const int seq_idx = blockIdx.y;
|
const int seq_idx = blockIdx.y;
|
||||||
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
|
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
|
||||||
|
|
||||||
@@ -116,12 +122,13 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|||||||
// th vectors of the query, and so on.
|
// th vectors of the query, and so on.
|
||||||
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
|
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
|
||||||
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||||
Q_vec q_vecs[NUM_VECS_PER_THREAD];
|
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
|
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
|
||||||
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
|
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
|
||||||
q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
|
q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
|
||||||
}
|
}
|
||||||
|
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
|
||||||
|
|
||||||
// Memory planning.
|
// Memory planning.
|
||||||
extern __shared__ char shared_mem[];
|
extern __shared__ char shared_mem[];
|
||||||
@@ -158,8 +165,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
|
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
|
||||||
const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
|
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
|
||||||
+ head_idx * HEAD_SIZE * BLOCK_SIZE
|
+ kv_head_idx * kv_head_stride
|
||||||
+ physical_block_offset * x;
|
+ physical_block_offset * x;
|
||||||
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
|
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
|
||||||
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
||||||
@@ -169,9 +176,9 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|||||||
|
|
||||||
// Compute dot product.
|
// Compute dot product.
|
||||||
// This includes a reduction across the threads in the same thread group.
|
// This includes a reduction across the threads in the same thread group.
|
||||||
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
|
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
|
||||||
// Add the ALiBi bias if slopes are given.
|
// Add the ALiBi bias if slopes are given.
|
||||||
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;
|
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
|
||||||
|
|
||||||
if (thread_group_offset == 0) {
|
if (thread_group_offset == 0) {
|
||||||
// Store the partial reductions to shared memory.
|
// Store the partial reductions to shared memory.
|
||||||
@@ -239,6 +246,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|||||||
accs[i] = 0.f;
|
accs[i] = 0.f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
scalar_t zero_value;
|
||||||
|
zero(zero_value);
|
||||||
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
|
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
|
||||||
const int physical_block_number = block_table[block_idx];
|
const int physical_block_number = block_table[block_idx];
|
||||||
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
|
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
|
||||||
@@ -246,14 +255,24 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|||||||
L_vec logits_vec;
|
L_vec logits_vec;
|
||||||
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
|
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
|
||||||
|
|
||||||
const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
|
const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
|
||||||
+ head_idx * HEAD_SIZE * BLOCK_SIZE;
|
+ kv_head_idx * kv_head_stride;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||||
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||||
if (row_idx < HEAD_SIZE) {
|
if (row_idx < HEAD_SIZE) {
|
||||||
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
||||||
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
||||||
|
if (block_idx == num_blocks - 1) {
|
||||||
|
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
|
||||||
|
// we should explicitly zero out the values since they may contain NaNs.
|
||||||
|
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
|
||||||
|
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j <= V_VEC_SIZE; j++) {
|
||||||
|
v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
|
||||||
|
}
|
||||||
|
}
|
||||||
accs[i] += dot(logits_vec, v_vec);
|
accs[i] += dot(logits_vec, v_vec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -322,18 +341,24 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
|
#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
|
||||||
|
cudaFuncSetAttribute( \
|
||||||
|
vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
|
||||||
|
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
|
||||||
vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
|
vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
|
||||||
<<<grid, block, shared_mem_size, stream>>>( \
|
<<<grid, block, shared_mem_size, stream>>>( \
|
||||||
out_ptr, \
|
out_ptr, \
|
||||||
query_ptr, \
|
query_ptr, \
|
||||||
key_cache_ptr, \
|
key_cache_ptr, \
|
||||||
value_cache_ptr, \
|
value_cache_ptr, \
|
||||||
|
head_mapping_ptr, \
|
||||||
scale, \
|
scale, \
|
||||||
block_tables_ptr, \
|
block_tables_ptr, \
|
||||||
context_lens_ptr, \
|
context_lens_ptr, \
|
||||||
max_num_blocks_per_seq, \
|
max_num_blocks_per_seq, \
|
||||||
alibi_slopes_ptr, \
|
alibi_slopes_ptr, \
|
||||||
query_stride);
|
q_stride, \
|
||||||
|
kv_block_stride, \
|
||||||
|
kv_head_stride);
|
||||||
|
|
||||||
// TODO(woosuk): Tune NUM_THREADS.
|
// TODO(woosuk): Tune NUM_THREADS.
|
||||||
template<
|
template<
|
||||||
@@ -345,6 +370,7 @@ void single_query_cached_kv_attention_launcher(
|
|||||||
torch::Tensor& query,
|
torch::Tensor& query,
|
||||||
torch::Tensor& key_cache,
|
torch::Tensor& key_cache,
|
||||||
torch::Tensor& value_cache,
|
torch::Tensor& value_cache,
|
||||||
|
torch::Tensor& head_mapping,
|
||||||
float scale,
|
float scale,
|
||||||
torch::Tensor& block_tables,
|
torch::Tensor& block_tables,
|
||||||
torch::Tensor& context_lens,
|
torch::Tensor& context_lens,
|
||||||
@@ -354,7 +380,9 @@ void single_query_cached_kv_attention_launcher(
|
|||||||
int num_heads = query.size(1);
|
int num_heads = query.size(1);
|
||||||
int head_size = query.size(2);
|
int head_size = query.size(2);
|
||||||
int max_num_blocks_per_seq = block_tables.size(1);
|
int max_num_blocks_per_seq = block_tables.size(1);
|
||||||
int query_stride = query.stride(0);
|
int q_stride = query.stride(0);
|
||||||
|
int kv_block_stride = key_cache.stride(0);
|
||||||
|
int kv_head_stride = key_cache.stride(1);
|
||||||
|
|
||||||
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||||
assert(head_size % thread_group_size == 0);
|
assert(head_size % thread_group_size == 0);
|
||||||
@@ -368,6 +396,7 @@ void single_query_cached_kv_attention_launcher(
|
|||||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||||
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||||
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||||
|
int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
|
||||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||||
|
|
||||||
@@ -375,6 +404,8 @@ void single_query_cached_kv_attention_launcher(
|
|||||||
int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
|
int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
|
||||||
int logits_size = padded_max_context_len * sizeof(float);
|
int logits_size = padded_max_context_len * sizeof(float);
|
||||||
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
||||||
|
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
|
||||||
|
// Keep that in sync with the logic here!
|
||||||
int shared_mem_size = std::max(logits_size, outputs_size);
|
int shared_mem_size = std::max(logits_size, outputs_size);
|
||||||
|
|
||||||
dim3 grid(num_heads, num_seqs);
|
dim3 grid(num_heads, num_seqs);
|
||||||
@@ -382,7 +413,7 @@ void single_query_cached_kv_attention_launcher(
|
|||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
switch (head_size) {
|
switch (head_size) {
|
||||||
// NOTE(woosuk): To reduce the compilation time, we omitted head sizes
|
// NOTE(woosuk): To reduce the compilation time, we omitted head sizes
|
||||||
// 32, 160, 192, 256.
|
// 32, 160, 192.
|
||||||
// case 32:
|
// case 32:
|
||||||
// LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
|
// LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
|
||||||
// break;
|
// break;
|
||||||
@@ -407,9 +438,9 @@ void single_query_cached_kv_attention_launcher(
|
|||||||
// case 192:
|
// case 192:
|
||||||
// LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
|
// LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
|
||||||
// break;
|
// break;
|
||||||
// case 256:
|
case 256:
|
||||||
// LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
|
LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
|
||||||
// break;
|
break;
|
||||||
default:
|
default:
|
||||||
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||||
break;
|
break;
|
||||||
@@ -422,6 +453,7 @@ void single_query_cached_kv_attention_launcher(
|
|||||||
query, \
|
query, \
|
||||||
key_cache, \
|
key_cache, \
|
||||||
value_cache, \
|
value_cache, \
|
||||||
|
head_mapping, \
|
||||||
scale, \
|
scale, \
|
||||||
block_tables, \
|
block_tables, \
|
||||||
context_lens, \
|
context_lens, \
|
||||||
@@ -469,6 +501,7 @@ void single_query_cached_kv_attention(
|
|||||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
|
torch::Tensor& head_mapping, // [num_heads]
|
||||||
float scale,
|
float scale,
|
||||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
torch::Tensor& context_lens, // [num_seqs]
|
torch::Tensor& context_lens, // [num_seqs]
|
||||||
|
|||||||
@@ -420,4 +420,14 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Zero-out a variable.
|
||||||
|
inline __device__ void zero(__nv_bfloat16& dst) {
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||||
|
assert(false);
|
||||||
|
#else
|
||||||
|
// Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
|
||||||
|
dst = __ushort_as_bfloat16((unsigned short)0x0000U);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|||||||
@@ -390,11 +390,6 @@ inline __device__ float sum(uint4 v) {
|
|||||||
return sum(c);
|
return sum(c);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Zero-out a vector.
|
|
||||||
inline __device__ void zero(uint16_t& dst) {
|
|
||||||
dst = uint16_t(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// From float32 to float16.
|
// From float32 to float16.
|
||||||
inline __device__ void from_float(uint16_t& dst, float src) {
|
inline __device__ void from_float(uint16_t& dst, float src) {
|
||||||
dst = float_to_half(src);
|
dst = float_to_half(src);
|
||||||
@@ -441,4 +436,9 @@ inline __device__ Float8_ to_float(uint4 u) {
|
|||||||
return tmp;
|
return tmp;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Zero-out a variable.
|
||||||
|
inline __device__ void zero(uint16_t& dst) {
|
||||||
|
dst = uint16_t(0);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|||||||
@@ -265,4 +265,9 @@ inline __device__ Float8_ to_float(Float8_ u) {
|
|||||||
return u;
|
return u;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Zero-out a variable.
|
||||||
|
inline __device__ void zero(float& dst) {
|
||||||
|
dst = 0.f;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <map>
|
#include <map>
|
||||||
@@ -125,9 +127,7 @@ void copy_blocks(
|
|||||||
dim3 grid(num_layers, num_pairs);
|
dim3 grid(num_layers, num_pairs);
|
||||||
dim3 block(std::min(1024, numel_per_block));
|
dim3 block(std::min(1024, numel_per_block));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
at::ScalarType::Half,
|
|
||||||
at::ScalarType::BFloat16,
|
|
||||||
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
|
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
|
||||||
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
key_cache_ptrs_tensor.data_ptr<int64_t>(),
|
key_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||||
@@ -202,9 +202,7 @@ void reshape_and_cache(
|
|||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(num_heads * head_size, 512));
|
dim3 block(std::min(num_heads * head_size, 512));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
at::ScalarType::Half,
|
|
||||||
at::ScalarType::BFloat16,
|
|
||||||
key.scalar_type(),
|
key.scalar_type(),
|
||||||
"reshape_and_cache_kernel",
|
"reshape_and_cache_kernel",
|
||||||
[&] {
|
[&] {
|
||||||
@@ -364,9 +362,7 @@ void gather_cached_kv(
|
|||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(num_heads * head_size, 512));
|
dim3 block(std::min(num_heads * head_size, 512));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
at::ScalarType::Half,
|
|
||||||
at::ScalarType::BFloat16,
|
|
||||||
key.scalar_type(),
|
key.scalar_type(),
|
||||||
"gather_cached_kv_kernel_optimized",
|
"gather_cached_kv_kernel_optimized",
|
||||||
[&] {
|
[&] {
|
||||||
|
|||||||
13
csrc/cuda_utils.cpp
Normal file
13
csrc/cuda_utils.cpp
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
int get_device_attribute(
|
||||||
|
int attribute,
|
||||||
|
int device_id);
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def(
|
||||||
|
"get_device_attribute",
|
||||||
|
&get_device_attribute,
|
||||||
|
"Gets the specified device attribute.");
|
||||||
|
}
|
||||||
|
|
||||||
14
csrc/cuda_utils_kernels.cu
Normal file
14
csrc/cuda_utils_kernels.cu
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
int get_device_attribute(
|
||||||
|
int attribute,
|
||||||
|
int device_id)
|
||||||
|
{
|
||||||
|
int device, value;
|
||||||
|
if (device_id < 0) {
|
||||||
|
cudaGetDevice(&device);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
device = device_id;
|
||||||
|
}
|
||||||
|
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
14
csrc/dispatch_utils.h
Normal file
14
csrc/dispatch_utils.h
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
/*
|
||||||
|
* Adapted from
|
||||||
|
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
|
||||||
|
*/
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||||
|
AT_DISPATCH_SWITCH( \
|
||||||
|
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include "dispatch_utils.h"
|
||||||
#include "reduction_utils.cuh"
|
#include "reduction_utils.cuh"
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
@@ -46,9 +47,7 @@ void rms_norm(
|
|||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(hidden_size, 1024));
|
dim3 block(std::min(hidden_size, 1024));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
at::ScalarType::Half,
|
|
||||||
at::ScalarType::BFloat16,
|
|
||||||
input.scalar_type(),
|
input.scalar_type(),
|
||||||
"rms_norm_kernel",
|
"rms_norm_kernel",
|
||||||
[&] {
|
[&] {
|
||||||
|
|||||||
@@ -1,15 +1,16 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
void rotary_embedding_neox(
|
void rotary_embedding(
|
||||||
torch::Tensor& positions,
|
torch::Tensor& positions,
|
||||||
torch::Tensor& query,
|
torch::Tensor& query,
|
||||||
torch::Tensor& key,
|
torch::Tensor& key,
|
||||||
int head_size,
|
int head_size,
|
||||||
torch::Tensor& cos_sin_cache);
|
torch::Tensor& cos_sin_cache,
|
||||||
|
bool is_neox);
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
m.def(
|
m.def(
|
||||||
"rotary_embedding_neox",
|
"rotary_embedding",
|
||||||
&rotary_embedding_neox,
|
&rotary_embedding,
|
||||||
"Apply GPT-NeoX style rotary embedding to query and key");
|
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,17 +1,51 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
template<typename scalar_t>
|
template<typename scalar_t, bool IS_NEOX>
|
||||||
__global__ void rotary_embedding_neox_kernel(
|
inline __device__ void apply_rotary_embedding(
|
||||||
|
scalar_t* __restrict__ arr,
|
||||||
|
const scalar_t* __restrict__ cos_ptr,
|
||||||
|
const scalar_t* __restrict__ sin_ptr,
|
||||||
|
int rot_offset,
|
||||||
|
int embed_dim)
|
||||||
|
{
|
||||||
|
int x_index, y_index;
|
||||||
|
scalar_t cos, sin;
|
||||||
|
if (IS_NEOX) {
|
||||||
|
// GPT-NeoX style rotary embedding.
|
||||||
|
x_index = rot_offset;
|
||||||
|
y_index = embed_dim + rot_offset;
|
||||||
|
cos = __ldg(cos_ptr + x_index);
|
||||||
|
sin = __ldg(sin_ptr + x_index);
|
||||||
|
} else {
|
||||||
|
// GPT-J style rotary embedding.
|
||||||
|
x_index = 2 * rot_offset;
|
||||||
|
y_index = 2 * rot_offset + 1;
|
||||||
|
cos = __ldg(cos_ptr + x_index / 2);
|
||||||
|
sin = __ldg(sin_ptr + x_index / 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
const scalar_t x = arr[x_index];
|
||||||
|
const scalar_t y = arr[y_index];
|
||||||
|
arr[x_index] = x * cos - y * sin;
|
||||||
|
arr[y_index] = y * cos + x * sin;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename scalar_t, bool IS_NEOX>
|
||||||
|
__global__ void rotary_embedding_kernel(
|
||||||
const int64_t* __restrict__ positions, // [num_tokens]
|
const int64_t* __restrict__ positions, // [num_tokens]
|
||||||
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
||||||
scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
|
||||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||||
const int rot_dim,
|
const int rot_dim,
|
||||||
const int stride,
|
const int query_stride,
|
||||||
|
const int key_stride,
|
||||||
const int num_heads,
|
const int num_heads,
|
||||||
|
const int num_kv_heads,
|
||||||
const int head_size) {
|
const int head_size) {
|
||||||
// Each thread block is responsible for one token.
|
// Each thread block is responsible for one token.
|
||||||
const int token_idx = blockIdx.x;
|
const int token_idx = blockIdx.x;
|
||||||
@@ -19,65 +53,75 @@ __global__ void rotary_embedding_neox_kernel(
|
|||||||
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||||
|
|
||||||
const int embed_dim = rot_dim / 2;
|
const int embed_dim = rot_dim / 2;
|
||||||
const int n = num_heads * embed_dim;
|
const scalar_t* cos_ptr = cache_ptr;
|
||||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
||||||
|
|
||||||
|
const int nq = num_heads * embed_dim;
|
||||||
|
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||||
const int head_idx = i / embed_dim;
|
const int head_idx = i / embed_dim;
|
||||||
const int token_head = token_idx * stride + head_idx * head_size;
|
const int token_head = token_idx * query_stride + head_idx * head_size;
|
||||||
|
|
||||||
const int rot_offset = i % embed_dim;
|
const int rot_offset = i % embed_dim;
|
||||||
const int x_index = rot_offset;
|
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
|
||||||
const int y_index = embed_dim + rot_offset;
|
sin_ptr, rot_offset, embed_dim);
|
||||||
|
}
|
||||||
|
|
||||||
const int out_x = token_idx * stride + head_idx * head_size + x_index;
|
const int nk = num_kv_heads * embed_dim;
|
||||||
const int out_y = token_idx * stride + head_idx * head_size + y_index;
|
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||||
|
const int head_idx = i / embed_dim;
|
||||||
const scalar_t cos = __ldg(cache_ptr + x_index);
|
const int token_head = token_idx * key_stride + head_idx * head_size;
|
||||||
const scalar_t sin = __ldg(cache_ptr + y_index);
|
const int rot_offset = i % embed_dim;
|
||||||
|
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
|
||||||
const scalar_t q_x = query[token_head + x_index];
|
sin_ptr, rot_offset, embed_dim);
|
||||||
const scalar_t q_y = query[token_head + y_index];
|
|
||||||
query[out_x] = q_x * cos - q_y * sin;
|
|
||||||
query[out_y] = q_y * cos + q_x * sin;
|
|
||||||
|
|
||||||
const scalar_t k_x = key[token_head + x_index];
|
|
||||||
const scalar_t k_y = key[token_head + y_index];
|
|
||||||
key[out_x] = k_x * cos - k_y * sin;
|
|
||||||
key[out_y] = k_y * cos + k_x * sin;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void rotary_embedding_neox(
|
void rotary_embedding(
|
||||||
torch::Tensor& positions, // [num_tokens]
|
torch::Tensor& positions, // [num_tokens]
|
||||||
torch::Tensor& query, // [num_tokens, num_heads * head_size]
|
torch::Tensor& query, // [num_tokens, num_heads * head_size]
|
||||||
torch::Tensor& key, // [num_tokens, num_heads * head_size]
|
torch::Tensor& key, // [num_tokens, num_kv_heads * head_size]
|
||||||
int head_size,
|
int head_size,
|
||||||
torch::Tensor& cos_sin_cache) // [max_position, rot_dim]
|
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||||
{
|
bool is_neox) {
|
||||||
int num_tokens = query.size(0);
|
int num_tokens = query.size(0);
|
||||||
int rot_dim = cos_sin_cache.size(1);
|
int rot_dim = cos_sin_cache.size(1);
|
||||||
int num_heads = query.size(1) / head_size;
|
int num_heads = query.size(1) / head_size;
|
||||||
int stride = query.stride(0);
|
int num_kv_heads = key.size(1) / head_size;
|
||||||
TORCH_CHECK(stride == key.stride(0));
|
int query_stride = query.stride(0);
|
||||||
|
int key_stride = key.stride(0);
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
at::ScalarType::Half,
|
|
||||||
at::ScalarType::BFloat16,
|
|
||||||
query.scalar_type(),
|
query.scalar_type(),
|
||||||
"rotary_embedding_neox",
|
"rotary_embedding",
|
||||||
[&] {
|
[&] {
|
||||||
vllm::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
if (is_neox) {
|
||||||
positions.data_ptr<int64_t>(),
|
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
|
||||||
query.data_ptr<scalar_t>(),
|
positions.data_ptr<int64_t>(),
|
||||||
key.data_ptr<scalar_t>(),
|
query.data_ptr<scalar_t>(),
|
||||||
cos_sin_cache.data_ptr<scalar_t>(),
|
key.data_ptr<scalar_t>(),
|
||||||
rot_dim,
|
cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
stride,
|
rot_dim,
|
||||||
num_heads,
|
query_stride,
|
||||||
head_size);
|
key_stride,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_size);
|
||||||
|
} else {
|
||||||
|
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
|
||||||
|
positions.data_ptr<int64_t>(),
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
|
rot_dim,
|
||||||
|
query_stride,
|
||||||
|
key_stride,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_size);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
15
csrc/quantization.cpp
Normal file
15
csrc/quantization.cpp
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
torch::Tensor awq_gemm(
|
||||||
|
torch::Tensor _in_feats,
|
||||||
|
torch::Tensor _kernel,
|
||||||
|
torch::Tensor _scaling_factors,
|
||||||
|
torch::Tensor _zeros,
|
||||||
|
int split_k_iters);
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def(
|
||||||
|
"awq_gemm",
|
||||||
|
&awq_gemm,
|
||||||
|
"Quantized GEMM for AWQ");
|
||||||
|
}
|
||||||
87
csrc/quantization/awq/dequantize.cuh
Normal file
87
csrc/quantization/awq/dequantize.cuh
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
/*
|
||||||
|
Adapted from https://github.com/mit-han-lab/llm-awq
|
||||||
|
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||||
|
@article{lin2023awq,
|
||||||
|
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
|
||||||
|
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
|
||||||
|
journal={arXiv},
|
||||||
|
year={2023}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
namespace awq {
|
||||||
|
|
||||||
|
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
|
||||||
|
{
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||||
|
assert(false);
|
||||||
|
#else
|
||||||
|
uint4 result;
|
||||||
|
|
||||||
|
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
|
||||||
|
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
|
||||||
|
|
||||||
|
// First, we extract the i4s and construct an intermediate fp16 number.
|
||||||
|
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
|
||||||
|
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
|
||||||
|
static constexpr uint32_t TOP_MASK = 0x00f000f0;
|
||||||
|
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||||
|
|
||||||
|
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
|
||||||
|
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
|
||||||
|
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
|
||||||
|
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
|
||||||
|
|
||||||
|
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
|
||||||
|
// immediately before required.
|
||||||
|
const uint32_t top_i4s = i4s >> 8;
|
||||||
|
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
|
||||||
|
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||||
|
: "=r"(h[0])
|
||||||
|
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||||
|
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
|
||||||
|
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||||
|
: "=r"(h[1])
|
||||||
|
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||||
|
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
|
||||||
|
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||||
|
: "=r"(h[2])
|
||||||
|
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||||
|
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
|
||||||
|
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||||
|
: "=r"(h[3])
|
||||||
|
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||||
|
|
||||||
|
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
|
||||||
|
// half2 ctor. In this case, I chose performance reliability over code readability.
|
||||||
|
|
||||||
|
// This is the half2 {1032, 1032} represented as an integer.
|
||||||
|
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
|
||||||
|
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
|
||||||
|
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
|
||||||
|
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
|
||||||
|
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
|
||||||
|
// This is the half2 {-72, -72} represented as an integer.
|
||||||
|
// static constexpr uint32_t NEG_72 = 0xd480d480;
|
||||||
|
// Haotian: Let's use {-64, -64}.
|
||||||
|
static constexpr uint32_t NEG_64 = 0xd400d400;
|
||||||
|
|
||||||
|
// Finally, we construct the output numbers.
|
||||||
|
// Convert elt_01
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
|
||||||
|
// Convert elt_23
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||||
|
// Convert elt_45
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
|
||||||
|
// Convert elt_67
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||||
|
|
||||||
|
return result;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace awq
|
||||||
|
} // namespace vllm
|
||||||
491
csrc/quantization/awq/gemm_kernels.cu
Normal file
491
csrc/quantization/awq/gemm_kernels.cu
Normal file
@@ -0,0 +1,491 @@
|
|||||||
|
/*
|
||||||
|
Adapted from https://github.com/mit-han-lab/llm-awq
|
||||||
|
@article{lin2023awq,
|
||||||
|
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
|
||||||
|
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
|
||||||
|
journal={arXiv},
|
||||||
|
year={2023}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
|
#include "dequantize.cuh"
|
||||||
|
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
namespace awq {
|
||||||
|
|
||||||
|
// Pack two half values.
|
||||||
|
static inline __device__ __host__ unsigned
|
||||||
|
__pack_half2(const half x, const half y) {
|
||||||
|
unsigned v0 = *((unsigned short *)&x);
|
||||||
|
unsigned v1 = *((unsigned short *)&y);
|
||||||
|
return (v1 << 16) | v0;
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
||||||
|
{
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||||
|
assert(false);
|
||||||
|
#else
|
||||||
|
static constexpr uint32_t ZERO = 0x0;
|
||||||
|
float C_warp[32];
|
||||||
|
__shared__ half A_shared[16 * (32 + 8)];
|
||||||
|
__shared__ half B_shared[32 * (128 + 8)];
|
||||||
|
|
||||||
|
__shared__ half scaling_factors_shared[128];
|
||||||
|
__shared__ half zeros_shared[128];
|
||||||
|
|
||||||
|
int j_factors1 = ((OC + 128 - 1) / 128);
|
||||||
|
int blockIdx_x = 0;
|
||||||
|
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||||
|
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||||
|
|
||||||
|
half A_shared_warp[8];
|
||||||
|
half B_shared_warp[32];
|
||||||
|
for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
|
||||||
|
for (int i = 0; i < 8; ++i) {
|
||||||
|
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||||
|
static constexpr int row_stride = 2 * 32 * 8 / 128;
|
||||||
|
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
|
||||||
|
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||||
|
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||||
|
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||||
|
|
||||||
|
half* A_ptr = A
|
||||||
|
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
||||||
|
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||||
|
|
||||||
|
int* B_ptr = B
|
||||||
|
+ ((int)threadIdx.y) * (OC / 8) * 2
|
||||||
|
+ (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
|
||||||
|
+ (((int)threadIdx.x) % (128 / 8)) * 1;
|
||||||
|
// Why * 1 in the above line?
|
||||||
|
|
||||||
|
half* A_shared_ptr = A_shared
|
||||||
|
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
||||||
|
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
||||||
|
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||||
|
|
||||||
|
half* B_shared_ptr = B_shared
|
||||||
|
+ ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
|
||||||
|
+ (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
|
||||||
|
+ (((int)threadIdx.x) % (128 / 8)) * 8;
|
||||||
|
|
||||||
|
int* zeros_ptr = zeros
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
|
||||||
|
+ ((int)threadIdx.x) % (128 / 8);
|
||||||
|
|
||||||
|
half* scaling_factors_ptr = scaling_factors
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * (128)
|
||||||
|
+ (((int)threadIdx.x) % (128 / 8)) * 8;
|
||||||
|
|
||||||
|
half* C_ptr = C
|
||||||
|
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * 128
|
||||||
|
+ ((int)threadIdx.y) * 64
|
||||||
|
+ (((int)threadIdx.x) % 4) * 2;
|
||||||
|
|
||||||
|
// preload s.f. and zeros
|
||||||
|
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
||||||
|
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
|
||||||
|
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
|
||||||
|
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
||||||
|
__syncthreads();
|
||||||
|
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||||
|
if (ld_A_flag)
|
||||||
|
{
|
||||||
|
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
||||||
|
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
||||||
|
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||||
|
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||||
|
/*
|
||||||
|
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
||||||
|
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||||
|
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||||
|
|
||||||
|
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
|
||||||
|
|
||||||
|
// B: 32 x 136 (128+8) float16
|
||||||
|
// each warp: 32 x 4
|
||||||
|
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
||||||
|
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
||||||
|
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||||
|
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||||
|
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||||
|
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||||
|
|
||||||
|
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||||
|
// - zero and * scale
|
||||||
|
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||||
|
/*
|
||||||
|
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
||||||
|
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
// write back
|
||||||
|
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
|
||||||
|
{
|
||||||
|
unsigned int addr;
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||||
|
: "=r"(addr)
|
||||||
|
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||||
|
"{%0, %1, %2, %3}, [%4];\n"
|
||||||
|
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
||||||
|
: "r"(addr)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
|
||||||
|
{
|
||||||
|
unsigned int addr;
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||||
|
: "=r"(addr)
|
||||||
|
: "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||||
|
);
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||||
|
"{%0, %1, %2, %3}, [%4];\n"
|
||||||
|
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||||
|
: "r"(addr)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||||
|
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||||
|
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Shang: Hoist loop invariance.
|
||||||
|
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
|
||||||
|
for (int local_id = 0; local_id < 8; ++local_id) {
|
||||||
|
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||||
|
if (row_offset < M)
|
||||||
|
{
|
||||||
|
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
||||||
|
{
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||||
|
assert(false);
|
||||||
|
#else
|
||||||
|
static constexpr uint32_t ZERO = 0x0;
|
||||||
|
float C_warp[32];
|
||||||
|
__shared__ half A_shared[16 * (32 + 8)];
|
||||||
|
__shared__ half B_shared[32 * (64 + 8)];
|
||||||
|
|
||||||
|
__shared__ half scaling_factors_shared[64];
|
||||||
|
__shared__ half zeros_shared[64];
|
||||||
|
|
||||||
|
int j_factors1 = ((OC + 64 - 1) / 64);
|
||||||
|
|
||||||
|
int blockIdx_x = 0;
|
||||||
|
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||||
|
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||||
|
|
||||||
|
half A_shared_warp[8];
|
||||||
|
half B_shared_warp[16];
|
||||||
|
for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
|
||||||
|
for (int i = 0; i < 8; ++i) {
|
||||||
|
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||||
|
static constexpr int row_stride = 2 * 32 * 8 / 64;
|
||||||
|
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
|
||||||
|
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||||
|
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||||
|
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||||
|
|
||||||
|
half* A_ptr = A
|
||||||
|
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
||||||
|
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||||
|
|
||||||
|
int* B_ptr = B
|
||||||
|
+ ((int)threadIdx.y) * (OC / 8) * 4
|
||||||
|
+ (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
|
||||||
|
+ (((int)threadIdx.x) % (64 / 8)) * 1;
|
||||||
|
// Why * 1 in the above line?
|
||||||
|
|
||||||
|
half* A_shared_ptr = A_shared
|
||||||
|
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
||||||
|
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
||||||
|
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||||
|
|
||||||
|
half* B_shared_ptr = B_shared
|
||||||
|
+ ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
|
||||||
|
+ (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
|
||||||
|
+ (((int)threadIdx.x) % (64 / 8)) * 8;
|
||||||
|
|
||||||
|
int* zeros_ptr = zeros
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
|
||||||
|
+ ((int)threadIdx.x) % (64 / 8);
|
||||||
|
|
||||||
|
half* scaling_factors_ptr = scaling_factors
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * (64)
|
||||||
|
+ (((int)threadIdx.x) % (64 / 8)) * 8;
|
||||||
|
|
||||||
|
half* C_ptr = C
|
||||||
|
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * 64
|
||||||
|
+ ((int)threadIdx.y) * 32
|
||||||
|
+ (((int)threadIdx.x) % 4) * 2;
|
||||||
|
|
||||||
|
// preload s.f. and zeros
|
||||||
|
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
||||||
|
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
|
||||||
|
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
|
||||||
|
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
||||||
|
__syncthreads();
|
||||||
|
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||||
|
if (ld_A_flag)
|
||||||
|
{
|
||||||
|
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
||||||
|
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
||||||
|
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||||
|
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||||
|
/*
|
||||||
|
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
||||||
|
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||||
|
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||||
|
|
||||||
|
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
|
||||||
|
|
||||||
|
// B: 32 x 136 (128+8) float16
|
||||||
|
// each warp: 32 x 4
|
||||||
|
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
||||||
|
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
||||||
|
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||||
|
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||||
|
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||||
|
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||||
|
|
||||||
|
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||||
|
// - zero and * scale
|
||||||
|
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||||
|
/*
|
||||||
|
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
||||||
|
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
// write back
|
||||||
|
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
|
||||||
|
{
|
||||||
|
{
|
||||||
|
unsigned int addr;
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||||
|
: "=r"(addr)
|
||||||
|
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||||
|
);
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||||
|
"{%0, %1, %2, %3}, [%4];\n"
|
||||||
|
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
||||||
|
: "r"(addr)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0)
|
||||||
|
{
|
||||||
|
{
|
||||||
|
unsigned int addr;
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||||
|
: "=r"(addr)
|
||||||
|
: "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||||
|
);
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||||
|
"{%0, %1, %2, %3}, [%4];\n"
|
||||||
|
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||||
|
: "r"(addr)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
|
||||||
|
{
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||||
|
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||||
|
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Shang: Hoist loop invariance.
|
||||||
|
for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
|
||||||
|
for (int local_id = 0; local_id < 8; ++local_id) {
|
||||||
|
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||||
|
if (row_offset < M)
|
||||||
|
{
|
||||||
|
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace awq
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
// in_feats: M, IC [float16]
|
||||||
|
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
|
||||||
|
// scaling_factors: IC // G, OC [float16]
|
||||||
|
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
|
||||||
|
// assume that batch_size < 16 for now
|
||||||
|
|
||||||
|
torch::Tensor awq_gemm(
|
||||||
|
torch::Tensor _in_feats,
|
||||||
|
torch::Tensor _kernel,
|
||||||
|
torch::Tensor _scaling_factors,
|
||||||
|
torch::Tensor _zeros,
|
||||||
|
int split_k_iters)
|
||||||
|
{
|
||||||
|
int num_in_feats = _in_feats.size(0);
|
||||||
|
int num_in_channels = _in_feats.size(1);
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
|
||||||
|
|
||||||
|
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
|
||||||
|
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
|
||||||
|
int num_out_feats = _out_feats.size(-2);
|
||||||
|
int num_out_channels = _out_feats.size(-1);
|
||||||
|
|
||||||
|
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
|
||||||
|
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
|
||||||
|
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
|
||||||
|
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
|
||||||
|
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
|
||||||
|
int group_size = num_in_channels / _scaling_factors.size(0);
|
||||||
|
|
||||||
|
if (num_out_channels % 64 != 0)
|
||||||
|
throw std::invalid_argument("OC is not multiple of cta_N = 64");
|
||||||
|
if (num_out_channels % 8 != 0)
|
||||||
|
throw std::invalid_argument("OC is not multiple of pack_num = 8");
|
||||||
|
if (group_size % 32 != 0)
|
||||||
|
throw std::invalid_argument("Group size should be a multiple of 32");
|
||||||
|
if (num_out_channels % group_size != 0)
|
||||||
|
throw std::invalid_argument("OC is not multiple of Group size");
|
||||||
|
|
||||||
|
if (num_out_channels % 128 == 0)
|
||||||
|
{
|
||||||
|
int j_factors1 = num_out_channels / 128 / 1;
|
||||||
|
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
|
||||||
|
// threadIdx.x: 32
|
||||||
|
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||||
|
dim3 threads_per_block(32, 2);
|
||||||
|
vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
|
||||||
|
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
||||||
|
}
|
||||||
|
else if (num_out_channels % 64 == 0)
|
||||||
|
{
|
||||||
|
int j_factors1 = num_out_channels / 64 / 1;
|
||||||
|
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
|
||||||
|
|
||||||
|
// threadIdx.x: 32
|
||||||
|
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||||
|
dim3 threads_per_block(32, 2);
|
||||||
|
vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
|
||||||
|
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
||||||
|
}
|
||||||
|
return _out_feats.sum(0);
|
||||||
|
}
|
||||||
@@ -3,31 +3,15 @@
|
|||||||
Installation
|
Installation
|
||||||
============
|
============
|
||||||
|
|
||||||
vLLM is a Python library that also contains some C++ and CUDA code.
|
vLLM is a Python library that also contains pre-compiled C++ and CUDA (11.8) binaries.
|
||||||
This additional code requires compilation on the user's machine.
|
|
||||||
|
|
||||||
Requirements
|
Requirements
|
||||||
------------
|
------------
|
||||||
|
|
||||||
* OS: Linux
|
* OS: Linux
|
||||||
* Python: 3.8 or higher
|
* Python: 3.8 -- 3.11
|
||||||
* CUDA: 11.0 -- 11.8
|
|
||||||
* GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)
|
* GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)
|
||||||
|
|
||||||
.. note::
|
|
||||||
As of now, vLLM does not support CUDA 12.
|
|
||||||
If you are using Hopper or Lovelace GPUs, please use CUDA 11.8 instead of CUDA 12.
|
|
||||||
|
|
||||||
.. tip::
|
|
||||||
If you have trouble installing vLLM, we recommend using the NVIDIA PyTorch Docker image.
|
|
||||||
|
|
||||||
.. code-block:: console
|
|
||||||
|
|
||||||
$ # Pull the Docker image with CUDA 11.8.
|
|
||||||
$ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
|
|
||||||
|
|
||||||
Inside the Docker container, please execute :code:`pip uninstall torch` before installing vLLM.
|
|
||||||
|
|
||||||
Install with pip
|
Install with pip
|
||||||
----------------
|
----------------
|
||||||
|
|
||||||
@@ -40,7 +24,7 @@ You can install vLLM using pip:
|
|||||||
$ conda activate myenv
|
$ conda activate myenv
|
||||||
|
|
||||||
$ # Install vLLM.
|
$ # Install vLLM.
|
||||||
$ pip install vllm # This may take 5-10 minutes.
|
$ pip install vllm
|
||||||
|
|
||||||
|
|
||||||
.. _build_from_source:
|
.. _build_from_source:
|
||||||
@@ -55,3 +39,12 @@ You can also build and install vLLM from source:
|
|||||||
$ git clone https://github.com/vllm-project/vllm.git
|
$ git clone https://github.com/vllm-project/vllm.git
|
||||||
$ cd vllm
|
$ cd vllm
|
||||||
$ pip install -e . # This may take 5-10 minutes.
|
$ pip install -e . # This may take 5-10 minutes.
|
||||||
|
|
||||||
|
.. tip::
|
||||||
|
If you have trouble building vLLM, we recommend using the NVIDIA PyTorch Docker image.
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ # Pull the Docker image with CUDA 11.8.
|
||||||
|
$ # Use `--ipc=host` to make sure the shared memory is large enough.
|
||||||
|
$ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:22.12-py3
|
||||||
|
|||||||
@@ -128,4 +128,4 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep
|
|||||||
prompt="San Francisco is a")
|
prompt="San Francisco is a")
|
||||||
print("Completion result:", completion)
|
print("Completion result:", completion)
|
||||||
|
|
||||||
For a more detailed client example, refer to `examples/openai_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_client.py>`_.
|
For a more detailed client example, refer to `examples/openai_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`_.
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ vLLM is flexible and easy to use with:
|
|||||||
For more information, check out the following:
|
For more information, check out the following:
|
||||||
|
|
||||||
* `vLLM announcing blog post <https://vllm.ai>`_ (intro to PagedAttention)
|
* `vLLM announcing blog post <https://vllm.ai>`_ (intro to PagedAttention)
|
||||||
|
* `vLLM paper <https://arxiv.org/abs/2309.06180>`_ (SOSP 2023)
|
||||||
* `How continuous batching enables 23x throughput in LLM inference while reducing p50 latency <https://www.anyscale.com/blog/continuous-batching-llm-inference>`_ by Cade Daniel et al.
|
* `How continuous batching enables 23x throughput in LLM inference while reducing p50 latency <https://www.anyscale.com/blog/continuous-batching-llm-inference>`_ by Cade Daniel et al.
|
||||||
|
|
||||||
|
|
||||||
@@ -62,6 +63,8 @@ Documentation
|
|||||||
:caption: Serving
|
:caption: Serving
|
||||||
|
|
||||||
serving/distributed_serving
|
serving/distributed_serving
|
||||||
|
serving/run_on_sky
|
||||||
|
serving/deploying_with_triton
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
|
|||||||
+ kv_caches: List[KVCache],
|
+ kv_caches: List[KVCache],
|
||||||
+ input_metadata: InputMetadata,
|
+ input_metadata: InputMetadata,
|
||||||
+ cache_events: Optional[List[torch.cuda.Event]],
|
+ cache_events: Optional[List[torch.cuda.Event]],
|
||||||
+) -> Dict[int, SequenceOutputs]:
|
+) -> SamplerOutput:
|
||||||
|
|
||||||
3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
|
3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
|
||||||
4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.
|
4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.
|
||||||
|
|||||||
@@ -14,27 +14,48 @@ Alongside each architecture, we include some popular models that use it.
|
|||||||
* - Architecture
|
* - Architecture
|
||||||
- Models
|
- Models
|
||||||
- Example HuggingFace Models
|
- Example HuggingFace Models
|
||||||
|
* - :code:`AquilaForCausalLM`
|
||||||
|
- Aquila
|
||||||
|
- :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc.
|
||||||
|
* - :code:`BaiChuanForCausalLM`
|
||||||
|
- Baichuan
|
||||||
|
- :code:`baichuan-inc/Baichuan-7B`, :code:`baichuan-inc/Baichuan-13B-Chat`, etc.
|
||||||
* - :code:`BloomForCausalLM`
|
* - :code:`BloomForCausalLM`
|
||||||
- BLOOM, BLOOMZ, BLOOMChat
|
- BLOOM, BLOOMZ, BLOOMChat
|
||||||
- :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
|
- :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
|
||||||
|
* - :code:`FalconForCausalLM`
|
||||||
|
- Falcon
|
||||||
|
- :code:`tiiuae/falcon-7b``, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
|
||||||
* - :code:`GPT2LMHeadModel`
|
* - :code:`GPT2LMHeadModel`
|
||||||
- GPT-2
|
- GPT-2
|
||||||
- :code:`gpt2`, :code:`gpt2-xl`, etc.
|
- :code:`gpt2`, :code:`gpt2-xl`, etc.
|
||||||
* - :code:`GPTBigCodeForCausalLM`
|
* - :code:`GPTBigCodeForCausalLM`
|
||||||
- StarCoder, SantaCoder, WizardCoder
|
- StarCoder, SantaCoder, WizardCoder
|
||||||
- :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc.
|
- :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc.
|
||||||
|
* - :code:`GPTJForCausalLM`
|
||||||
|
- GPT-J
|
||||||
|
- :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc.
|
||||||
* - :code:`GPTNeoXForCausalLM`
|
* - :code:`GPTNeoXForCausalLM`
|
||||||
- GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM
|
- GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM
|
||||||
- :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
|
- :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
|
||||||
|
* - :code:`InternLMForCausalLM`
|
||||||
|
- InternLM
|
||||||
|
- :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc.
|
||||||
* - :code:`LlamaForCausalLM`
|
* - :code:`LlamaForCausalLM`
|
||||||
- LLaMA, Vicuna, Alpaca, Koala, Guanaco
|
- LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco
|
||||||
- :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, :code:`JosephusCheung/Guanaco`, etc.
|
- :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, etc.
|
||||||
|
* - :code:`MistralForCausalLM`
|
||||||
|
- Mistral, Mistral-Instruct
|
||||||
|
- :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
|
||||||
* - :code:`MPTForCausalLM`
|
* - :code:`MPTForCausalLM`
|
||||||
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
|
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
|
||||||
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
|
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
|
||||||
* - :code:`OPTForCausalLM`
|
* - :code:`OPTForCausalLM`
|
||||||
- OPT, OPT-IML
|
- OPT, OPT-IML
|
||||||
- :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
|
- :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
|
||||||
|
* - :code:`QWenLMHeadModel`
|
||||||
|
- Qwen
|
||||||
|
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
|
||||||
|
|
||||||
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
|
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
|
||||||
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
|
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
|
||||||
|
|||||||
6
docs/source/serving/deploying_with_triton.rst
Normal file
6
docs/source/serving/deploying_with_triton.rst
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
.. _deploying_with_triton:
|
||||||
|
|
||||||
|
Deploying with NVIDIA Triton
|
||||||
|
============================
|
||||||
|
|
||||||
|
The `Triton Inference Server <https://github.com/triton-inference-server>`_ hosts a tutorial demonstrating how to quickly deploy a simple `facebook/opt-125m <https://huggingface.co/facebook/opt-125m>`_ model using vLLM. Please see `Deploying a vLLM model in Triton <https://github.com/triton-inference-server/tutorials/blob/main/Quick_Deploy/vLLM/README.md#deploying-a-vllm-model-in-triton>`_ for more details.
|
||||||
69
docs/source/serving/run_on_sky.rst
Normal file
69
docs/source/serving/run_on_sky.rst
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
.. _on_cloud:
|
||||||
|
|
||||||
|
Running on clouds with SkyPilot
|
||||||
|
===============================
|
||||||
|
|
||||||
|
.. raw:: html
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="https://imgur.com/yxtzPEu.png" alt="vLLM"/>
|
||||||
|
</p>
|
||||||
|
|
||||||
|
vLLM can be run on the cloud to scale to multiple GPUs with `SkyPilot <https://github.com/skypilot-org/skypilot>`__, an open-source framework for running LLMs on any cloud.
|
||||||
|
|
||||||
|
To install SkyPilot and setup your cloud credentials, run:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ pip install skypilot
|
||||||
|
$ sky check
|
||||||
|
|
||||||
|
See the vLLM SkyPilot YAML for serving, `serving.yaml <https://github.com/skypilot-org/skypilot/blob/master/llm/vllm/serve.yaml>`__.
|
||||||
|
|
||||||
|
.. code-block:: yaml
|
||||||
|
|
||||||
|
resources:
|
||||||
|
accelerators: A100
|
||||||
|
|
||||||
|
envs:
|
||||||
|
MODEL_NAME: decapoda-research/llama-13b-hf
|
||||||
|
TOKENIZER: hf-internal-testing/llama-tokenizer
|
||||||
|
|
||||||
|
setup: |
|
||||||
|
conda create -n vllm python=3.9 -y
|
||||||
|
conda activate vllm
|
||||||
|
git clone https://github.com/vllm-project/vllm.git
|
||||||
|
cd vllm
|
||||||
|
pip install .
|
||||||
|
pip install gradio
|
||||||
|
|
||||||
|
run: |
|
||||||
|
conda activate vllm
|
||||||
|
echo 'Starting vllm api server...'
|
||||||
|
python -u -m vllm.entrypoints.api_server \
|
||||||
|
--model $MODEL_NAME \
|
||||||
|
--tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
|
||||||
|
--tokenizer $TOKENIZER 2>&1 | tee api_server.log &
|
||||||
|
echo 'Waiting for vllm api server to start...'
|
||||||
|
while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done
|
||||||
|
echo 'Starting gradio server...'
|
||||||
|
python vllm/examples/gradio_webserver.py
|
||||||
|
|
||||||
|
Start the serving the LLaMA-13B model on an A100 GPU:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ sky launch serving.yaml
|
||||||
|
|
||||||
|
Check the output of the command. There will be a sharable gradio link (like the last line of the following). Open it in your browser to use the LLaMA model to do the text completion.
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
(task, pid=7431) Running on public URL: https://<gradio-hash>.gradio.live
|
||||||
|
|
||||||
|
**Optional**: Serve the 65B model instead of the default 13B and use more GPU:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
sky launch -c vllm-serve-new -s serve.yaml --gpus A100:8 --env MODEL_NAME=decapoda-research/llama-65b-hf
|
||||||
|
|
||||||
@@ -10,7 +10,8 @@ def main(args: argparse.Namespace):
|
|||||||
|
|
||||||
# Test the following prompts.
|
# Test the following prompts.
|
||||||
test_prompts = [
|
test_prompts = [
|
||||||
("A robot may not injure a human being", SamplingParams()),
|
("A robot may not injure a human being",
|
||||||
|
SamplingParams(temperature=0.0)),
|
||||||
("To be or not to be,",
|
("To be or not to be,",
|
||||||
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
|
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
|
||||||
("What is the meaning of life?",
|
("What is the meaning of life?",
|
||||||
@@ -27,7 +28,7 @@ def main(args: argparse.Namespace):
|
|||||||
# Run the engine by calling `engine.step()` manually.
|
# Run the engine by calling `engine.step()` manually.
|
||||||
request_id = 0
|
request_id = 0
|
||||||
while True:
|
while True:
|
||||||
# To test iteration-level scheduling, we add one request at each step.
|
# To test continuous batching, we add one request at each step.
|
||||||
if test_prompts:
|
if test_prompts:
|
||||||
prompt, sampling_params = test_prompts.pop(0)
|
prompt, sampling_params = test_prompts.pop(0)
|
||||||
engine.add_request(str(request_id), prompt, sampling_params)
|
engine.add_request(str(request_id), prompt, sampling_params)
|
||||||
|
|||||||
33
examples/openai_chatcompletion_client.py
Normal file
33
examples/openai_chatcompletion_client.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import openai
|
||||||
|
|
||||||
|
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||||
|
openai.api_key = "EMPTY"
|
||||||
|
openai.api_base = "http://localhost:8000/v1"
|
||||||
|
|
||||||
|
# List models API
|
||||||
|
models = openai.Model.list()
|
||||||
|
print("Models:", models)
|
||||||
|
|
||||||
|
model = models["data"][0]["id"]
|
||||||
|
|
||||||
|
# Chat completion API
|
||||||
|
chat_completion = openai.ChatCompletion.create(
|
||||||
|
model=model,
|
||||||
|
messages=[{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant."
|
||||||
|
}, {
|
||||||
|
"role": "user",
|
||||||
|
"content": "Who won the world series in 2020?"
|
||||||
|
}, {
|
||||||
|
"role":
|
||||||
|
"assistant",
|
||||||
|
"content":
|
||||||
|
"The Los Angeles Dodgers won the World Series in 2020."
|
||||||
|
}, {
|
||||||
|
"role": "user",
|
||||||
|
"content": "Where was it played?"
|
||||||
|
}])
|
||||||
|
|
||||||
|
print("Chat completion results:")
|
||||||
|
print(chat_completion)
|
||||||
@@ -3,26 +3,26 @@ import openai
|
|||||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||||
openai.api_key = "EMPTY"
|
openai.api_key = "EMPTY"
|
||||||
openai.api_base = "http://localhost:8000/v1"
|
openai.api_base = "http://localhost:8000/v1"
|
||||||
model = "facebook/opt-125m"
|
|
||||||
|
|
||||||
# Test list models API
|
# List models API
|
||||||
models = openai.Model.list()
|
models = openai.Model.list()
|
||||||
print("Models:", models)
|
print("Models:", models)
|
||||||
|
|
||||||
# Test completion API
|
model = models["data"][0]["id"]
|
||||||
stream = True
|
|
||||||
|
# Completion API
|
||||||
|
stream = False
|
||||||
completion = openai.Completion.create(
|
completion = openai.Completion.create(
|
||||||
model=model,
|
model=model,
|
||||||
prompt="A robot may not injure a human being",
|
prompt="A robot may not injure a human being",
|
||||||
echo=False,
|
echo=False,
|
||||||
n=2,
|
n=2,
|
||||||
best_of=3,
|
|
||||||
stream=stream,
|
stream=stream,
|
||||||
logprobs=3)
|
logprobs=3)
|
||||||
|
|
||||||
# print the completion
|
print("Completion results:")
|
||||||
if stream:
|
if stream:
|
||||||
for c in completion:
|
for c in completion:
|
||||||
print(c)
|
print(c)
|
||||||
else:
|
else:
|
||||||
print("Completion result:", completion)
|
print(completion)
|
||||||
@@ -10,3 +10,5 @@ types-setuptools
|
|||||||
|
|
||||||
# testing
|
# testing
|
||||||
pytest
|
pytest
|
||||||
|
pytest-forked
|
||||||
|
pytest-asyncio
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
ninja # For faster builds.
|
ninja # For faster builds.
|
||||||
psutil
|
psutil
|
||||||
ray
|
ray >= 2.5.1
|
||||||
|
pandas # Required for Ray data.
|
||||||
|
pyarrow # Required for Ray data.
|
||||||
sentencepiece # Required for LLaMA tokenizer.
|
sentencepiece # Required for LLaMA tokenizer.
|
||||||
numpy
|
numpy
|
||||||
torch >= 2.0.0
|
torch >= 2.0.0
|
||||||
transformers >= 4.28.0 # Required for LLaMA.
|
transformers >= 4.33.1 # Required for Code Llama.
|
||||||
xformers >= 0.0.19
|
xformers >= 0.0.22
|
||||||
fastapi
|
fastapi
|
||||||
uvicorn
|
uvicorn[standard]
|
||||||
pydantic # Required for OpenAI server.
|
pydantic < 2 # Required for OpenAI server.
|
||||||
fschat # Required for OpenAI ChatCompletion Endpoint.
|
|
||||||
|
|||||||
160
setup.py
160
setup.py
@@ -3,6 +3,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
from typing import List, Set
|
from typing import List, Set
|
||||||
|
import warnings
|
||||||
|
|
||||||
from packaging.version import parse, Version
|
from packaging.version import parse, Version
|
||||||
import setuptools
|
import setuptools
|
||||||
@@ -11,6 +12,9 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
|
|||||||
|
|
||||||
ROOT_DIR = os.path.dirname(__file__)
|
ROOT_DIR = os.path.dirname(__file__)
|
||||||
|
|
||||||
|
# Supported NVIDIA GPU architectures.
|
||||||
|
SUPPORTED_ARCHS = ["7.0", "7.5", "8.0", "8.6", "8.9", "9.0"]
|
||||||
|
|
||||||
# Compiler flags.
|
# Compiler flags.
|
||||||
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
|
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
|
||||||
# TODO(woosuk): Should we use -O3?
|
# TODO(woosuk): Should we use -O3?
|
||||||
@@ -22,7 +26,7 @@ NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
|||||||
|
|
||||||
if CUDA_HOME is None:
|
if CUDA_HOME is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Cannot find CUDA_HOME. CUDA must be available in order to build the package.")
|
"Cannot find CUDA_HOME. CUDA must be available to build the package.")
|
||||||
|
|
||||||
|
|
||||||
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
||||||
@@ -38,32 +42,82 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
|||||||
return nvcc_cuda_version
|
return nvcc_cuda_version
|
||||||
|
|
||||||
|
|
||||||
# Collect the compute capabilities of all available GPUs.
|
def get_torch_arch_list() -> Set[str]:
|
||||||
device_count = torch.cuda.device_count()
|
# TORCH_CUDA_ARCH_LIST can have one or more architectures,
|
||||||
compute_capabilities: Set[int] = set()
|
# e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
|
||||||
for i in range(device_count):
|
# compiler to additionally include PTX code that can be runtime-compiled
|
||||||
major, minor = torch.cuda.get_device_capability(i)
|
# and executed on the 8.6 or newer architectures. While the PTX code will
|
||||||
if major < 7:
|
# not give the best performance on the newer architectures, it provides
|
||||||
raise RuntimeError(
|
# forward compatibility.
|
||||||
"GPUs with compute capability less than 7.0 are not supported.")
|
valid_arch_strs = SUPPORTED_ARCHS + [s + "+PTX" for s in SUPPORTED_ARCHS]
|
||||||
compute_capabilities.add(major * 10 + minor)
|
arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
|
||||||
# If no GPU is available, add all supported compute capabilities.
|
if arch_list is None:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
# List are separated by ; or space.
|
||||||
|
arch_list = arch_list.replace(" ", ";").split(";")
|
||||||
|
for arch in arch_list:
|
||||||
|
if arch not in valid_arch_strs:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported CUDA arch ({arch}). "
|
||||||
|
f"Valid CUDA arch strings are: {valid_arch_strs}.")
|
||||||
|
return set(arch_list)
|
||||||
|
|
||||||
|
|
||||||
|
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
|
||||||
|
compute_capabilities = get_torch_arch_list()
|
||||||
if not compute_capabilities:
|
if not compute_capabilities:
|
||||||
compute_capabilities = {70, 75, 80, 86, 90}
|
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
|
||||||
# Add target compute capabilities to NVCC flags.
|
# GPUs on the current machine.
|
||||||
for capability in compute_capabilities:
|
device_count = torch.cuda.device_count()
|
||||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
|
for i in range(device_count):
|
||||||
|
major, minor = torch.cuda.get_device_capability(i)
|
||||||
|
if major < 7:
|
||||||
|
raise RuntimeError(
|
||||||
|
"GPUs with compute capability below 7.0 are not supported.")
|
||||||
|
compute_capabilities.add(f"{major}.{minor}")
|
||||||
|
|
||||||
|
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
||||||
|
if not compute_capabilities:
|
||||||
|
# If no GPU is specified nor available, add all supported architectures
|
||||||
|
# based on the NVCC CUDA version.
|
||||||
|
compute_capabilities = set(SUPPORTED_ARCHS)
|
||||||
|
if nvcc_cuda_version < Version("11.1"):
|
||||||
|
compute_capabilities.remove("8.6")
|
||||||
|
if nvcc_cuda_version < Version("11.8"):
|
||||||
|
compute_capabilities.remove("8.9")
|
||||||
|
compute_capabilities.remove("9.0")
|
||||||
|
|
||||||
# Validate the NVCC CUDA version.
|
# Validate the NVCC CUDA version.
|
||||||
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
|
||||||
if nvcc_cuda_version < Version("11.0"):
|
if nvcc_cuda_version < Version("11.0"):
|
||||||
raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
|
raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
|
||||||
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
|
if nvcc_cuda_version < Version("11.1"):
|
||||||
raise RuntimeError(
|
if any(cc.startswith("8.6") for cc in compute_capabilities):
|
||||||
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
|
raise RuntimeError(
|
||||||
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
"CUDA 11.1 or higher is required for compute capability 8.6.")
|
||||||
raise RuntimeError(
|
if nvcc_cuda_version < Version("11.8"):
|
||||||
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
|
if any(cc.startswith("8.9") for cc in compute_capabilities):
|
||||||
|
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
|
||||||
|
# However, GPUs with compute capability 8.9 can also run the code generated by
|
||||||
|
# the previous versions of CUDA 11 and targeting compute capability 8.0.
|
||||||
|
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
|
||||||
|
# instead of 8.9.
|
||||||
|
warnings.warn(
|
||||||
|
"CUDA 11.8 or higher is required for compute capability 8.9. "
|
||||||
|
"Targeting compute capability 8.0 instead.")
|
||||||
|
compute_capabilities = set(cc for cc in compute_capabilities
|
||||||
|
if not cc.startswith("8.9"))
|
||||||
|
compute_capabilities.add("8.0+PTX")
|
||||||
|
if any(cc.startswith("9.0") for cc in compute_capabilities):
|
||||||
|
raise RuntimeError(
|
||||||
|
"CUDA 11.8 or higher is required for compute capability 9.0.")
|
||||||
|
|
||||||
|
# Add target compute capabilities to NVCC flags.
|
||||||
|
for capability in compute_capabilities:
|
||||||
|
num = capability[0] + capability[2]
|
||||||
|
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
|
||||||
|
if capability.endswith("+PTX"):
|
||||||
|
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
|
||||||
|
|
||||||
# Use NVCC threads to parallelize the build.
|
# Use NVCC threads to parallelize the build.
|
||||||
if nvcc_cuda_version >= Version("11.2"):
|
if nvcc_cuda_version >= Version("11.2"):
|
||||||
@@ -76,7 +130,10 @@ ext_modules = []
|
|||||||
cache_extension = CUDAExtension(
|
cache_extension = CUDAExtension(
|
||||||
name="vllm.cache_ops",
|
name="vllm.cache_ops",
|
||||||
sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
|
sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
|
||||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
extra_compile_args={
|
||||||
|
"cxx": CXX_FLAGS,
|
||||||
|
"nvcc": NVCC_FLAGS,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
ext_modules.append(cache_extension)
|
ext_modules.append(cache_extension)
|
||||||
|
|
||||||
@@ -84,7 +141,10 @@ ext_modules.append(cache_extension)
|
|||||||
attention_extension = CUDAExtension(
|
attention_extension = CUDAExtension(
|
||||||
name="vllm.attention_ops",
|
name="vllm.attention_ops",
|
||||||
sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
|
sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
|
||||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
extra_compile_args={
|
||||||
|
"cxx": CXX_FLAGS,
|
||||||
|
"nvcc": NVCC_FLAGS,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
ext_modules.append(attention_extension)
|
ext_modules.append(attention_extension)
|
||||||
|
|
||||||
@@ -92,7 +152,10 @@ ext_modules.append(attention_extension)
|
|||||||
positional_encoding_extension = CUDAExtension(
|
positional_encoding_extension = CUDAExtension(
|
||||||
name="vllm.pos_encoding_ops",
|
name="vllm.pos_encoding_ops",
|
||||||
sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
|
sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
|
||||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
extra_compile_args={
|
||||||
|
"cxx": CXX_FLAGS,
|
||||||
|
"nvcc": NVCC_FLAGS,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
ext_modules.append(positional_encoding_extension)
|
ext_modules.append(positional_encoding_extension)
|
||||||
|
|
||||||
@@ -100,7 +163,10 @@ ext_modules.append(positional_encoding_extension)
|
|||||||
layernorm_extension = CUDAExtension(
|
layernorm_extension = CUDAExtension(
|
||||||
name="vllm.layernorm_ops",
|
name="vllm.layernorm_ops",
|
||||||
sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
|
sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
|
||||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
extra_compile_args={
|
||||||
|
"cxx": CXX_FLAGS,
|
||||||
|
"nvcc": NVCC_FLAGS,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
ext_modules.append(layernorm_extension)
|
ext_modules.append(layernorm_extension)
|
||||||
|
|
||||||
@@ -108,10 +174,38 @@ ext_modules.append(layernorm_extension)
|
|||||||
activation_extension = CUDAExtension(
|
activation_extension = CUDAExtension(
|
||||||
name="vllm.activation_ops",
|
name="vllm.activation_ops",
|
||||||
sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
|
sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
|
||||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
extra_compile_args={
|
||||||
|
"cxx": CXX_FLAGS,
|
||||||
|
"nvcc": NVCC_FLAGS,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
ext_modules.append(activation_extension)
|
ext_modules.append(activation_extension)
|
||||||
|
|
||||||
|
# Quantization kernels.
|
||||||
|
quantization_extension = CUDAExtension(
|
||||||
|
name="vllm.quantization_ops",
|
||||||
|
sources=[
|
||||||
|
"csrc/quantization.cpp",
|
||||||
|
"csrc/quantization/awq/gemm_kernels.cu",
|
||||||
|
],
|
||||||
|
extra_compile_args={
|
||||||
|
"cxx": CXX_FLAGS,
|
||||||
|
"nvcc": NVCC_FLAGS,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ext_modules.append(quantization_extension)
|
||||||
|
|
||||||
|
# Misc. CUDA utils.
|
||||||
|
cuda_utils_extension = CUDAExtension(
|
||||||
|
name="vllm.cuda_utils",
|
||||||
|
sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"],
|
||||||
|
extra_compile_args={
|
||||||
|
"cxx": CXX_FLAGS,
|
||||||
|
"nvcc": NVCC_FLAGS,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ext_modules.append(cuda_utils_extension)
|
||||||
|
|
||||||
|
|
||||||
def get_path(*filepath) -> str:
|
def get_path(*filepath) -> str:
|
||||||
return os.path.join(ROOT_DIR, *filepath)
|
return os.path.join(ROOT_DIR, *filepath)
|
||||||
@@ -123,8 +217,8 @@ def find_version(filepath: str):
|
|||||||
Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
|
Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
|
||||||
"""
|
"""
|
||||||
with open(filepath) as fp:
|
with open(filepath) as fp:
|
||||||
version_match = re.search(
|
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
|
||||||
r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M)
|
fp.read(), re.M)
|
||||||
if version_match:
|
if version_match:
|
||||||
return version_match.group(1)
|
return version_match.group(1)
|
||||||
raise RuntimeError("Unable to find version string.")
|
raise RuntimeError("Unable to find version string.")
|
||||||
@@ -147,7 +241,8 @@ setuptools.setup(
|
|||||||
version=find_version(get_path("vllm", "__init__.py")),
|
version=find_version(get_path("vllm", "__init__.py")),
|
||||||
author="vLLM Team",
|
author="vLLM Team",
|
||||||
license="Apache 2.0",
|
license="Apache 2.0",
|
||||||
description="A high-throughput and memory-efficient inference and serving engine for LLMs",
|
description=("A high-throughput and memory-efficient inference and "
|
||||||
|
"serving engine for LLMs"),
|
||||||
long_description=read_readme(),
|
long_description=read_readme(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
url="https://github.com/vllm-project/vllm",
|
url="https://github.com/vllm-project/vllm",
|
||||||
@@ -159,11 +254,12 @@ setuptools.setup(
|
|||||||
"Programming Language :: Python :: 3.8",
|
"Programming Language :: Python :: 3.8",
|
||||||
"Programming Language :: Python :: 3.9",
|
"Programming Language :: Python :: 3.9",
|
||||||
"Programming Language :: Python :: 3.10",
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
"License :: OSI Approved :: Apache Software License",
|
"License :: OSI Approved :: Apache Software License",
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
],
|
],
|
||||||
packages=setuptools.find_packages(
|
packages=setuptools.find_packages(exclude=("benchmarks", "csrc", "docs",
|
||||||
exclude=("assets", "benchmarks", "csrc", "docs", "examples", "tests")),
|
"examples", "tests")),
|
||||||
python_requires=">=3.8",
|
python_requires=">=3.8",
|
||||||
install_requires=get_requirements(),
|
install_requires=get_requirements(),
|
||||||
ext_modules=ext_modules,
|
ext_modules=ext_modules,
|
||||||
|
|||||||
50
tests/async_engine/api_server_async_engine.py
Normal file
50
tests/async_engine/api_server_async_engine.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""vllm.entrypoints.api_server with some extra logging for testing."""
|
||||||
|
import argparse
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
from fastapi.responses import JSONResponse, Response
|
||||||
|
|
||||||
|
import vllm.entrypoints.api_server
|
||||||
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
|
|
||||||
|
app = vllm.entrypoints.api_server.app
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncLLMEngineWithStats(AsyncLLMEngine):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._num_aborts = 0
|
||||||
|
|
||||||
|
async def abort(self, request_id: str) -> None:
|
||||||
|
await super().abort(request_id)
|
||||||
|
self._num_aborts += 1
|
||||||
|
|
||||||
|
def testing_stats(self) -> Dict[str, Any]:
|
||||||
|
return {"num_aborted_requests": self._num_aborts}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/stats")
|
||||||
|
def stats() -> Response:
|
||||||
|
"""Get the statistics of the engine."""
|
||||||
|
return JSONResponse(engine.testing_stats())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--host", type=str, default="localhost")
|
||||||
|
parser.add_argument("--port", type=int, default=8000)
|
||||||
|
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||||
|
engine = AsyncLLMEngineWithStats.from_engine_args(engine_args)
|
||||||
|
vllm.entrypoints.api_server.engine = engine
|
||||||
|
uvicorn.run(
|
||||||
|
app,
|
||||||
|
host=args.host,
|
||||||
|
port=args.port,
|
||||||
|
log_level="debug",
|
||||||
|
timeout_keep_alive=vllm.entrypoints.api_server.TIMEOUT_KEEP_ALIVE)
|
||||||
86
tests/async_engine/test_api_server.py
Normal file
86
tests/async_engine/test_api_server.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from multiprocessing import Pool
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def _query_server(prompt: str) -> dict:
|
||||||
|
response = requests.post("http://localhost:8000/generate",
|
||||||
|
json={
|
||||||
|
"prompt": prompt,
|
||||||
|
"max_tokens": 100,
|
||||||
|
"temperature": 0,
|
||||||
|
"ignore_eos": True
|
||||||
|
})
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def api_server():
|
||||||
|
script_path = Path(__file__).parent.joinpath(
|
||||||
|
"api_server_async_engine.py").absolute()
|
||||||
|
uvicorn_process = subprocess.Popen([
|
||||||
|
sys.executable, "-u",
|
||||||
|
str(script_path), "--model", "facebook/opt-125m"
|
||||||
|
])
|
||||||
|
yield
|
||||||
|
uvicorn_process.terminate()
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_server(api_server):
|
||||||
|
"""
|
||||||
|
Run the API server and test it.
|
||||||
|
|
||||||
|
We run both the server and requests in separate processes.
|
||||||
|
|
||||||
|
We test that the server can handle incoming requests, including
|
||||||
|
multiple requests at the same time, and that it can handle requests
|
||||||
|
being cancelled without crashing.
|
||||||
|
"""
|
||||||
|
with Pool(32) as pool:
|
||||||
|
# Wait until the server is ready
|
||||||
|
prompts = ["Hello world"] * 1
|
||||||
|
result = None
|
||||||
|
while not result:
|
||||||
|
try:
|
||||||
|
for result in pool.map(_query_server, prompts):
|
||||||
|
break
|
||||||
|
except:
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# Actual tests start here
|
||||||
|
# Try with 1 prompt
|
||||||
|
for result in pool.map(_query_server, prompts):
|
||||||
|
assert result
|
||||||
|
|
||||||
|
num_aborted_requests = requests.get(
|
||||||
|
"http://localhost:8000/stats").json()["num_aborted_requests"]
|
||||||
|
assert num_aborted_requests == 0
|
||||||
|
|
||||||
|
# Try with 100 prompts
|
||||||
|
prompts = ["Hello world"] * 100
|
||||||
|
for result in pool.map(_query_server, prompts):
|
||||||
|
assert result
|
||||||
|
|
||||||
|
# Cancel requests
|
||||||
|
pool.map_async(_query_server, prompts)
|
||||||
|
time.sleep(0.01)
|
||||||
|
pool.terminate()
|
||||||
|
pool.join()
|
||||||
|
|
||||||
|
# check cancellation stats
|
||||||
|
num_aborted_requests = requests.get(
|
||||||
|
"http://localhost:8000/stats").json()["num_aborted_requests"]
|
||||||
|
assert num_aborted_requests > 0
|
||||||
|
|
||||||
|
# check that server still runs after cancellations
|
||||||
|
with Pool(32) as pool:
|
||||||
|
# Try with 100 prompts
|
||||||
|
prompts = ["Hello world"] * 100
|
||||||
|
for result in pool.map(_query_server, prompts):
|
||||||
|
assert result
|
||||||
80
tests/async_engine/test_async_llm_engine.py
Normal file
80
tests/async_engine/test_async_llm_engine.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
import asyncio
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RequestOutput:
|
||||||
|
request_id: int
|
||||||
|
finished: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class MockEngine:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.step_calls = 0
|
||||||
|
self.add_request_calls = 0
|
||||||
|
self.abort_request_calls = 0
|
||||||
|
self.request_id = None
|
||||||
|
|
||||||
|
async def step_async(self):
|
||||||
|
self.step_calls += 1
|
||||||
|
return [RequestOutput(
|
||||||
|
request_id=self.request_id)] if self.request_id else []
|
||||||
|
|
||||||
|
def generate(self, request_id):
|
||||||
|
self.request_id = request_id
|
||||||
|
|
||||||
|
def stop_generating(self):
|
||||||
|
self.request_id = None
|
||||||
|
|
||||||
|
def add_request(self, **kwargs):
|
||||||
|
self.add_request_calls += 1
|
||||||
|
return
|
||||||
|
|
||||||
|
def abort_request(self, request_id):
|
||||||
|
self.abort_request_calls += 1
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class MockAsyncLLMEngine(AsyncLLMEngine):
|
||||||
|
|
||||||
|
def _init_engine(self, *args, **kwargs):
|
||||||
|
return MockEngine()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_requests_event():
|
||||||
|
engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
|
||||||
|
engine.start_background_loop()
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
assert engine.engine.step_calls == 0
|
||||||
|
|
||||||
|
await engine.add_request("1", "", None)
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
assert engine.engine.add_request_calls == 1
|
||||||
|
assert engine.engine.step_calls == 1
|
||||||
|
|
||||||
|
await engine.add_request("2", "", None)
|
||||||
|
engine.engine.generate("2")
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert engine.engine.add_request_calls == 2
|
||||||
|
assert engine.engine.step_calls == 2
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert engine.engine.step_calls == 3
|
||||||
|
engine.engine.stop_generating()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert engine.engine.step_calls == 4
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert engine.engine.step_calls == 4
|
||||||
|
|
||||||
|
await engine.add_request("3", "", None)
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
assert engine.engine.add_request_calls == 3
|
||||||
|
assert engine.engine.step_calls == 5
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
assert engine.engine.add_request_calls == 3
|
||||||
|
assert engine.engine.step_calls == 5
|
||||||
75
tests/async_engine/test_request_tracker.py
Normal file
75
tests/async_engine/test_request_tracker.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.engine.async_llm_engine import RequestTracker
|
||||||
|
from vllm.outputs import RequestOutput
|
||||||
|
|
||||||
|
|
||||||
|
class DummyEvent:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._flag = False
|
||||||
|
|
||||||
|
def set(self):
|
||||||
|
self._flag = True
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self._flag = False
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_tracker():
|
||||||
|
tracker = RequestTracker()
|
||||||
|
tracker.new_requests_event = DummyEvent()
|
||||||
|
stream_1 = tracker.add_request("1")
|
||||||
|
assert tracker.new_requests_event._flag
|
||||||
|
new, finished = tracker.get_new_and_finished_requests()
|
||||||
|
assert not tracker.new_requests_event._flag
|
||||||
|
assert len(new) == 1
|
||||||
|
assert new[0]["request_id"] == "1"
|
||||||
|
assert not finished
|
||||||
|
assert not stream_1.finished
|
||||||
|
|
||||||
|
stream_2 = tracker.add_request("2")
|
||||||
|
stream_3 = tracker.add_request("3")
|
||||||
|
assert tracker.new_requests_event._flag
|
||||||
|
new, finished = tracker.get_new_and_finished_requests()
|
||||||
|
assert not tracker.new_requests_event._flag
|
||||||
|
assert len(new) == 2
|
||||||
|
assert new[0]["request_id"] == "2"
|
||||||
|
assert new[1]["request_id"] == "3"
|
||||||
|
assert not finished
|
||||||
|
assert not stream_2.finished
|
||||||
|
assert not stream_3.finished
|
||||||
|
|
||||||
|
# request_ids must be unique
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
tracker.add_request("1")
|
||||||
|
assert not tracker.new_requests_event._flag
|
||||||
|
|
||||||
|
tracker.abort_request("1")
|
||||||
|
new, finished = tracker.get_new_and_finished_requests()
|
||||||
|
assert len(finished) == 1
|
||||||
|
assert "1" in finished
|
||||||
|
assert not new
|
||||||
|
assert stream_1.finished
|
||||||
|
|
||||||
|
stream_4 = tracker.add_request("4")
|
||||||
|
tracker.abort_request("4")
|
||||||
|
assert tracker.new_requests_event._flag
|
||||||
|
new, finished = tracker.get_new_and_finished_requests()
|
||||||
|
assert len(finished) == 1
|
||||||
|
assert "4" in finished
|
||||||
|
assert not new
|
||||||
|
assert stream_4.finished
|
||||||
|
|
||||||
|
stream_5 = tracker.add_request("5")
|
||||||
|
assert tracker.new_requests_event._flag
|
||||||
|
tracker.process_request_output(
|
||||||
|
RequestOutput("2", "output", [], [], finished=True))
|
||||||
|
new, finished = tracker.get_new_and_finished_requests()
|
||||||
|
assert not tracker.new_requests_event._flag
|
||||||
|
assert len(finished) == 1
|
||||||
|
assert "2" in finished
|
||||||
|
assert len(new) == 1
|
||||||
|
assert new[0]["request_id"] == "5"
|
||||||
|
assert stream_2.finished
|
||||||
|
assert not stream_5.finished
|
||||||
178
tests/conftest.py
Normal file
178
tests/conftest.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|
||||||
|
_TEST_PROMPTS = [
|
||||||
|
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
|
||||||
|
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
|
||||||
|
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
|
||||||
|
"Describe the basic components of a neural network and how it can be trained.",
|
||||||
|
"Write a short story about a robot that dreams for the first time.",
|
||||||
|
"Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
|
||||||
|
"Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
|
||||||
|
"Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def example_prompts() -> List[str]:
|
||||||
|
return _TEST_PROMPTS
|
||||||
|
|
||||||
|
|
||||||
|
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||||
|
"half": torch.half,
|
||||||
|
"bfloat16": torch.bfloat16,
|
||||||
|
"float": torch.float,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class HfRunner:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
tokenizer_name: Optional[str] = None,
|
||||||
|
dtype: str = "half",
|
||||||
|
) -> None:
|
||||||
|
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
|
||||||
|
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
trust_remote_code=True,
|
||||||
|
).cuda()
|
||||||
|
if tokenizer_name is None:
|
||||||
|
tokenizer_name = model_name
|
||||||
|
self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
**kwargs,
|
||||||
|
) -> List[Tuple[List[int], str]]:
|
||||||
|
outputs: List[Tuple[List[int], str]] = []
|
||||||
|
for prompt in prompts:
|
||||||
|
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
|
||||||
|
output_ids = self.model.generate(
|
||||||
|
input_ids.cuda(),
|
||||||
|
use_cache=True,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
output_str = self.tokenizer.batch_decode(
|
||||||
|
output_ids,
|
||||||
|
skip_special_tokens=True,
|
||||||
|
clean_up_tokenization_spaces=False,
|
||||||
|
)
|
||||||
|
output_ids = output_ids.cpu().tolist()
|
||||||
|
outputs.append((output_ids, output_str))
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def generate_greedy(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
max_tokens: int,
|
||||||
|
) -> List[Tuple[List[int], str]]:
|
||||||
|
outputs = self.generate(prompts,
|
||||||
|
do_sample=False,
|
||||||
|
max_new_tokens=max_tokens)
|
||||||
|
for i in range(len(outputs)):
|
||||||
|
output_ids, output_str = outputs[i]
|
||||||
|
outputs[i] = (output_ids[0], output_str[0])
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def generate_beam_search(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
beam_width: int,
|
||||||
|
max_tokens: int,
|
||||||
|
) -> List[Tuple[List[int], str]]:
|
||||||
|
outputs = self.generate(prompts,
|
||||||
|
do_sample=False,
|
||||||
|
max_new_tokens=max_tokens,
|
||||||
|
num_beams=beam_width,
|
||||||
|
num_return_sequences=beam_width)
|
||||||
|
for i in range(len(outputs)):
|
||||||
|
output_ids, output_str = outputs[i]
|
||||||
|
for j in range(len(output_ids)):
|
||||||
|
output_ids[j] = [
|
||||||
|
x for x in output_ids[j]
|
||||||
|
if x != self.tokenizer.pad_token_id
|
||||||
|
]
|
||||||
|
outputs[i] = (output_ids, output_str)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def hf_runner():
|
||||||
|
return HfRunner
|
||||||
|
|
||||||
|
|
||||||
|
class VllmRunner:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
tokenizer_name: Optional[str] = None,
|
||||||
|
dtype: str = "half",
|
||||||
|
) -> None:
|
||||||
|
self.model = LLM(
|
||||||
|
model=model_name,
|
||||||
|
tokenizer=tokenizer_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
dtype=dtype,
|
||||||
|
swap_space=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
) -> List[Tuple[List[int], str]]:
|
||||||
|
req_outputs = self.model.generate(prompts,
|
||||||
|
sampling_params=sampling_params)
|
||||||
|
outputs = []
|
||||||
|
for req_output in req_outputs:
|
||||||
|
prompt_str = req_output.prompt
|
||||||
|
prompt_ids = req_output.prompt_token_ids
|
||||||
|
req_sample_output_ids = []
|
||||||
|
req_sample_output_strs = []
|
||||||
|
for sample in req_output.outputs:
|
||||||
|
output_str = sample.text
|
||||||
|
output_ids = sample.token_ids
|
||||||
|
req_sample_output_ids.append(prompt_ids + output_ids)
|
||||||
|
req_sample_output_strs.append(prompt_str + output_str)
|
||||||
|
outputs.append((req_sample_output_ids, req_sample_output_strs))
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def generate_greedy(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
max_tokens: int,
|
||||||
|
) -> List[Tuple[List[int], str]]:
|
||||||
|
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
||||||
|
outputs = self.generate(prompts, greedy_params)
|
||||||
|
return [(output_ids[0], output_str[0])
|
||||||
|
for output_ids, output_str in outputs]
|
||||||
|
|
||||||
|
def generate_beam_search(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
beam_width: int,
|
||||||
|
max_tokens: int,
|
||||||
|
) -> List[Tuple[List[int], str]]:
|
||||||
|
beam_search_params = SamplingParams(n=beam_width,
|
||||||
|
use_beam_search=True,
|
||||||
|
temperature=0.0,
|
||||||
|
max_tokens=max_tokens)
|
||||||
|
outputs = self.generate(prompts, beam_search_params)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def vllm_runner():
|
||||||
|
return VllmRunner
|
||||||
62
tests/engine/test_detokenize.py
Normal file
62
tests/engine/test_detokenize.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from vllm.transformers_utils.tokenizer import detokenize_incrementally
|
||||||
|
|
||||||
|
TRUTH = [
|
||||||
|
"Hello here, this is a simple test",
|
||||||
|
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",
|
||||||
|
"我很感谢你的热情"
|
||||||
|
]
|
||||||
|
TOKENIZERS = [
|
||||||
|
"facebook/opt-125m",
|
||||||
|
"gpt2",
|
||||||
|
"bigcode/tiny_starcoder_py",
|
||||||
|
"EleutherAI/gpt-j-6b",
|
||||||
|
"EleutherAI/pythia-70m",
|
||||||
|
"bigscience/bloom-560m",
|
||||||
|
"mosaicml/mpt-7b",
|
||||||
|
"tiiuae/falcon-7b",
|
||||||
|
"meta-llama/Llama-2-7b-hf",
|
||||||
|
"codellama/CodeLlama-7b-hf",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _run_incremental_decode(tokenizer, all_input_ids,
|
||||||
|
skip_special_tokens: bool):
|
||||||
|
decoded_text = ""
|
||||||
|
offset = 0
|
||||||
|
token_offset = 0
|
||||||
|
prev_tokens = None
|
||||||
|
for i in range(len(all_input_ids)):
|
||||||
|
new_tokens, text, offset, token_offset = detokenize_incrementally(
|
||||||
|
tokenizer,
|
||||||
|
all_input_ids[:i + 1],
|
||||||
|
prev_tokens,
|
||||||
|
offset,
|
||||||
|
token_offset,
|
||||||
|
skip_special_tokens=skip_special_tokens)
|
||||||
|
decoded_text += text
|
||||||
|
if prev_tokens is None:
|
||||||
|
prev_tokens = new_tokens
|
||||||
|
else:
|
||||||
|
prev_tokens += new_tokens
|
||||||
|
return decoded_text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("truth", TRUTH)
|
||||||
|
@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
|
||||||
|
@pytest.mark.parametrize("skip_special_tokens", (True, False))
|
||||||
|
def test_decode_streaming(tokenizer_id, truth, skip_special_tokens):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
|
||||||
|
all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
|
||||||
|
if skip_special_tokens:
|
||||||
|
all_input_ids = ([tokenizer.bos_token_id]
|
||||||
|
if tokenizer.bos_token_id is not None else
|
||||||
|
[]) + all_input_ids + [tokenizer.eos_token_id]
|
||||||
|
|
||||||
|
decoded_text = _run_incremental_decode(
|
||||||
|
tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens)
|
||||||
|
|
||||||
|
assert decoded_text == truth
|
||||||
43
tests/kernels/conftest.py
Normal file
43
tests/kernels/conftest.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def create_kv_caches(
|
||||||
|
num_blocks: int,
|
||||||
|
block_size: int,
|
||||||
|
num_layers: int,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
|
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
scale = head_size**-0.5
|
||||||
|
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||||
|
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
||||||
|
key_caches = []
|
||||||
|
for _ in range(num_layers):
|
||||||
|
key_cache = torch.empty(size=key_cache_shape,
|
||||||
|
dtype=dtype,
|
||||||
|
device='cuda')
|
||||||
|
key_cache.uniform_(-scale, scale)
|
||||||
|
key_caches.append(key_cache)
|
||||||
|
|
||||||
|
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
||||||
|
value_caches = []
|
||||||
|
for _ in range(num_layers):
|
||||||
|
value_cache = torch.empty(size=value_cache_shape,
|
||||||
|
dtype=dtype,
|
||||||
|
device='cuda')
|
||||||
|
value_cache.uniform_(-scale, scale)
|
||||||
|
value_caches.append(value_cache)
|
||||||
|
return key_caches, value_caches
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def kv_cache_factory():
|
||||||
|
return create_kv_caches
|
||||||
@@ -1,20 +1,34 @@
|
|||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from transformers.activations import get_activation
|
||||||
|
|
||||||
from vllm import activation_ops
|
from vllm import activation_ops
|
||||||
|
|
||||||
|
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||||
|
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
|
||||||
|
D = [512, 4096, 5120, 13824] # Arbitrary values for testing
|
||||||
|
SEEDS = [0]
|
||||||
|
|
||||||
|
|
||||||
def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
|
def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
|
||||||
x1, x2 = x.chunk(chunks=2, dim=1)
|
x1, x2 = x.chunk(chunks=2, dim=1)
|
||||||
return F.silu(x1) * x2
|
return F.silu(x1) * x2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||||
|
@pytest.mark.parametrize("d", D)
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def run_silu_and_mul(
|
def test_silu_and_mul(
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
d: int,
|
d: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
|
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
|
||||||
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||||
activation_ops.silu_and_mul(out, x)
|
activation_ops.silu_and_mul(out, x)
|
||||||
@@ -22,9 +36,40 @@ def run_silu_and_mul(
|
|||||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
def test_silu_and_mul() -> None:
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
@pytest.mark.parametrize("d", D)
|
||||||
for num_tokens in [7, 83, 2048]:
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
for d in [512, 4096, 5120, 13824]:
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
|
@torch.inference_mode()
|
||||||
run_silu_and_mul(num_tokens, d, dtype)
|
def test_gelu_new(
|
||||||
|
num_tokens: int,
|
||||||
|
d: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
|
) -> None:
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
|
||||||
|
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||||
|
activation_ops.gelu_new(out, x)
|
||||||
|
ref_out = get_activation("gelu_new")(x)
|
||||||
|
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||||
|
@pytest.mark.parametrize("d", D)
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
|
def test_gelu_fast(
|
||||||
|
num_tokens: int,
|
||||||
|
d: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
|
) -> None:
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
|
||||||
|
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||||
|
activation_ops.gelu_fast(out, x)
|
||||||
|
ref_out = get_activation("gelu_fast")(x)
|
||||||
|
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||||
|
|||||||
@@ -1,14 +1,28 @@
|
|||||||
import random
|
import random
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from xformers import ops as xops
|
from xformers import ops as xops
|
||||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||||
|
|
||||||
from vllm import attention_ops
|
from vllm import attention_ops
|
||||||
|
from vllm.utils import get_max_shared_memory_bytes
|
||||||
|
|
||||||
MAX_SEQ_LEN = 4096
|
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||||
TEST_SEED = 0
|
# This will change depending on the compute capability.
|
||||||
|
# - 512 as a buffer
|
||||||
|
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
|
||||||
|
NUM_BLOCKS = 128 # Arbitrary values for testing
|
||||||
|
|
||||||
|
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||||
|
NUM_GEN_SEQS = [7] # Arbitrary values for testing
|
||||||
|
NUM_PREFILL_SEQS = [1, 3, 7] # Arbitrary values for testing
|
||||||
|
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
|
||||||
|
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||||
|
BLOCK_SIZES = [8, 16, 32]
|
||||||
|
USE_ALIBI = [False, True]
|
||||||
|
SEEDS = [0]
|
||||||
|
|
||||||
|
|
||||||
def ref_masked_attention(
|
def ref_masked_attention(
|
||||||
@@ -18,29 +32,34 @@ def ref_masked_attention(
|
|||||||
scale: float,
|
scale: float,
|
||||||
attn_mask: Optional[torch.Tensor] = None,
|
attn_mask: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
query = query * scale
|
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
|
||||||
attn = torch.einsum('qhd,khd->hqk', query, key)
|
|
||||||
if attn_mask is not None:
|
if attn_mask is not None:
|
||||||
attn = attn + attn_mask
|
attn_weights = attn_weights + attn_mask.float()
|
||||||
attn = torch.softmax(attn, dim=-1)
|
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
|
||||||
out = torch.einsum('hqk,khd->qhd', attn, value)
|
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def ref_single_query_cached_kv_attention(
|
def ref_single_query_cached_kv_attention(
|
||||||
output: torch.Tensor,
|
output: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
|
num_queries_per_kv: int,
|
||||||
key_cache: torch.Tensor,
|
key_cache: torch.Tensor,
|
||||||
value_cache: torch.Tensor,
|
value_cache: torch.Tensor,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
context_lens: torch.Tensor,
|
context_lens: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
|
alibi_slopes: Optional[torch.Tensor],
|
||||||
) -> None:
|
) -> None:
|
||||||
num_heads = value_cache.shape[1]
|
num_query_heads = query.shape[1]
|
||||||
|
num_kv_heads = value_cache.shape[1]
|
||||||
head_size = value_cache.shape[2]
|
head_size = value_cache.shape[2]
|
||||||
block_size = value_cache.shape[3]
|
block_size = value_cache.shape[3]
|
||||||
|
num_seqs = query.shape[0]
|
||||||
|
|
||||||
num_input_tokens = query.shape[0]
|
block_tables = block_tables.cpu().tolist()
|
||||||
for i in range(num_input_tokens):
|
context_lens = context_lens.cpu().tolist()
|
||||||
|
for i in range(num_seqs):
|
||||||
q = query[i].unsqueeze(0)
|
q = query[i].unsqueeze(0)
|
||||||
block_table = block_tables[i]
|
block_table = block_tables[i]
|
||||||
context_len = int(context_lens[i])
|
context_len = int(context_lens[i])
|
||||||
@@ -52,30 +71,139 @@ def ref_single_query_cached_kv_attention(
|
|||||||
block_offset = j % block_size
|
block_offset = j % block_size
|
||||||
|
|
||||||
k = key_cache[block_number, :, :, block_offset, :]
|
k = key_cache[block_number, :, :, block_offset, :]
|
||||||
k = k.reshape(num_heads, head_size)
|
k = k.reshape(num_kv_heads, head_size)
|
||||||
keys.append(k)
|
keys.append(k)
|
||||||
|
|
||||||
v = value_cache[block_number, :, :, block_offset]
|
v = value_cache[block_number, :, :, block_offset]
|
||||||
values.append(v)
|
values.append(v)
|
||||||
keys = torch.stack(keys, dim=0)
|
keys = torch.stack(keys, dim=0)
|
||||||
values = torch.stack(values, dim=0)
|
values = torch.stack(values, dim=0)
|
||||||
|
if num_queries_per_kv > 1:
|
||||||
|
# Handle MQA and GQA
|
||||||
|
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
|
||||||
|
values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
|
||||||
|
|
||||||
scale = 1.0 / (head_size**0.5)
|
alibi_bias = None
|
||||||
out = ref_masked_attention(q, keys, values, scale)
|
if alibi_slopes is not None:
|
||||||
out = out.view(num_heads, head_size)
|
# Create the ALiBi bias used in the paged attention kernel.
|
||||||
|
position_ids = torch.arange(context_len, device="cuda").int()
|
||||||
|
alibi_bias = (position_ids - context_len + 1).float()
|
||||||
|
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
|
||||||
|
1, 1, -1)
|
||||||
|
|
||||||
|
out = ref_masked_attention(q, keys, values, scale, alibi_bias)
|
||||||
|
out = out.view(num_query_heads, head_size)
|
||||||
output[i].copy_(out, non_blocking=True)
|
output[i].copy_(out, non_blocking=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
|
||||||
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
|
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
|
||||||
|
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_single_query_cached_kv_attention(
|
||||||
|
kv_cache_factory,
|
||||||
|
num_seqs: int,
|
||||||
|
num_heads: Tuple[int, int],
|
||||||
|
head_size: int,
|
||||||
|
use_alibi: bool,
|
||||||
|
block_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
|
) -> None:
|
||||||
|
random.seed(seed)
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
scale = float(1.0 / (head_size**0.5))
|
||||||
|
num_query_heads, num_kv_heads = num_heads
|
||||||
|
query = torch.empty(num_seqs,
|
||||||
|
num_query_heads,
|
||||||
|
head_size,
|
||||||
|
dtype=dtype,
|
||||||
|
device="cuda")
|
||||||
|
query.uniform_(-scale, scale)
|
||||||
|
|
||||||
|
assert num_query_heads % num_kv_heads == 0
|
||||||
|
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||||
|
head_mapping = torch.repeat_interleave(
|
||||||
|
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
|
||||||
|
num_queries_per_kv)
|
||||||
|
alibi_slopes = None
|
||||||
|
if use_alibi:
|
||||||
|
alibi_slopes = torch.randn(num_query_heads,
|
||||||
|
dtype=torch.float,
|
||||||
|
device="cuda")
|
||||||
|
|
||||||
|
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||||
|
context_lens[-1] = MAX_SEQ_LEN
|
||||||
|
max_context_len = max(context_lens)
|
||||||
|
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
|
||||||
|
|
||||||
|
# Create the block tables.
|
||||||
|
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||||
|
block_tables = []
|
||||||
|
for _ in range(num_seqs):
|
||||||
|
block_table = [
|
||||||
|
random.randint(0, NUM_BLOCKS - 1)
|
||||||
|
for _ in range(max_num_blocks_per_seq)
|
||||||
|
]
|
||||||
|
block_tables.append(block_table)
|
||||||
|
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
|
||||||
|
|
||||||
|
# Create the KV caches.
|
||||||
|
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
|
||||||
|
num_kv_heads, head_size, dtype,
|
||||||
|
seed)
|
||||||
|
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||||
|
|
||||||
|
# Call the paged attention kernel.
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
attention_ops.single_query_cached_kv_attention(
|
||||||
|
output,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
head_mapping,
|
||||||
|
scale,
|
||||||
|
block_tables,
|
||||||
|
context_lens,
|
||||||
|
block_size,
|
||||||
|
max_context_len,
|
||||||
|
alibi_slopes,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the reference implementation.
|
||||||
|
ref_output = torch.empty_like(query)
|
||||||
|
ref_single_query_cached_kv_attention(
|
||||||
|
ref_output,
|
||||||
|
query,
|
||||||
|
num_queries_per_kv,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
block_tables,
|
||||||
|
context_lens,
|
||||||
|
scale,
|
||||||
|
alibi_slopes,
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE(woosuk): Due to the kernel-level differences in the two
|
||||||
|
# implementations, there is a small numerical difference in the two
|
||||||
|
# outputs. Thus, we use a relaxed tolerance for the test.
|
||||||
|
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
def ref_multi_query_kv_attention(
|
def ref_multi_query_kv_attention(
|
||||||
cu_seq_lens: List[int],
|
cu_seq_lens: List[int],
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
head_size = query.shape[-1]
|
|
||||||
scale = 1.0 / (head_size**0.5)
|
|
||||||
|
|
||||||
num_seqs = len(cu_seq_lens) - 1
|
num_seqs = len(cu_seq_lens) - 1
|
||||||
ref_outputs = []
|
ref_outputs = []
|
||||||
for i in range(num_seqs):
|
for i in range(num_seqs):
|
||||||
@@ -87,7 +215,7 @@ def ref_multi_query_kv_attention(
|
|||||||
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
|
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
|
||||||
diagonal=1)
|
diagonal=1)
|
||||||
attn_mask = attn_mask * torch.finfo(dtype).min
|
attn_mask = attn_mask * torch.finfo(dtype).min
|
||||||
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
|
attn_mask = attn_mask.to(dtype=dtype, device="cuda")
|
||||||
|
|
||||||
ref_output = ref_masked_attention(
|
ref_output = ref_masked_attention(
|
||||||
query[start_idx:end_idx],
|
query[start_idx:end_idx],
|
||||||
@@ -101,161 +229,47 @@ def ref_multi_query_kv_attention(
|
|||||||
return ref_output
|
return ref_output
|
||||||
|
|
||||||
|
|
||||||
def ref_multi_query_cached_kv_attention(
|
# TODO(woosuk): Add tests for USE_ALIBI=True.
|
||||||
cu_query_lens: List[int],
|
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
|
||||||
query: torch.Tensor,
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
key_cache: torch.Tensor,
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
value_cache: torch.Tensor,
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
block_tables: torch.Tensor,
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
context_lens: torch.Tensor,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
num_heads = value_cache.shape[1]
|
|
||||||
head_size = value_cache.shape[2]
|
|
||||||
block_size = value_cache.shape[3]
|
|
||||||
scale = 1.0 / (head_size**0.5)
|
|
||||||
|
|
||||||
num_queries = len(cu_query_lens) - 1
|
|
||||||
ref_outputs = []
|
|
||||||
for i in range(num_queries):
|
|
||||||
start_idx = cu_query_lens[i]
|
|
||||||
end_idx = cu_query_lens[i + 1]
|
|
||||||
query_len = end_idx - start_idx
|
|
||||||
context_len = int(context_lens[i])
|
|
||||||
block_table = block_tables[i]
|
|
||||||
|
|
||||||
# Create attention mask
|
|
||||||
attn_mask = torch.triu(torch.ones(query_len, context_len),
|
|
||||||
diagonal=context_len - query_len + 1) * -1e5
|
|
||||||
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
|
|
||||||
|
|
||||||
keys = []
|
|
||||||
values = []
|
|
||||||
for j in range(context_len):
|
|
||||||
block_number = int(block_table[j // block_size])
|
|
||||||
block_offset = j % block_size
|
|
||||||
|
|
||||||
k = key_cache[block_number, :, :, block_offset, :]
|
|
||||||
k = k.reshape(num_heads, head_size)
|
|
||||||
keys.append(k)
|
|
||||||
|
|
||||||
v = value_cache[block_number, :, :, block_offset]
|
|
||||||
values.append(v)
|
|
||||||
keys = torch.stack(keys, dim=0)
|
|
||||||
values = torch.stack(values, dim=0)
|
|
||||||
|
|
||||||
ref_output = ref_masked_attention(
|
|
||||||
query[start_idx:end_idx],
|
|
||||||
keys,
|
|
||||||
values,
|
|
||||||
scale,
|
|
||||||
attn_mask=attn_mask,
|
|
||||||
)
|
|
||||||
ref_outputs.append(ref_output)
|
|
||||||
ref_output = torch.cat(ref_outputs, dim=0)
|
|
||||||
return ref_output
|
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def run_single_query_cached_kv_attention(
|
def test_multi_query_kv_attention(
|
||||||
num_tokens: int,
|
|
||||||
num_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
block_size: int,
|
|
||||||
num_blocks: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
) -> None:
|
|
||||||
qkv = torch.empty(num_tokens,
|
|
||||||
3,
|
|
||||||
num_heads,
|
|
||||||
head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device='cuda')
|
|
||||||
qkv.uniform_(-1e-3, 1e-3)
|
|
||||||
query, _, _ = qkv.unbind(dim=1)
|
|
||||||
|
|
||||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
|
||||||
key_block_shape = (num_heads, head_size // x, block_size, x)
|
|
||||||
key_cache = torch.empty(size=(num_blocks, *key_block_shape),
|
|
||||||
dtype=dtype,
|
|
||||||
device='cuda')
|
|
||||||
key_cache.uniform_(-1e-3, 1e-3)
|
|
||||||
value_block_shape = (num_heads, head_size, block_size)
|
|
||||||
value_cache = torch.empty(size=(num_blocks, *value_block_shape),
|
|
||||||
dtype=dtype,
|
|
||||||
device='cuda')
|
|
||||||
value_cache.uniform_(-1e-3, 1e-3)
|
|
||||||
|
|
||||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
|
|
||||||
max_context_len = max(context_lens)
|
|
||||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
|
|
||||||
|
|
||||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
|
||||||
block_tables = []
|
|
||||||
for _ in range(num_tokens):
|
|
||||||
block_table = [
|
|
||||||
random.randint(0, num_blocks - 1)
|
|
||||||
for _ in range(max_num_blocks_per_seq)
|
|
||||||
]
|
|
||||||
block_tables.append(block_table)
|
|
||||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
|
|
||||||
|
|
||||||
scale = float(1.0 / (head_size**0.5))
|
|
||||||
output = torch.empty(num_tokens,
|
|
||||||
num_heads,
|
|
||||||
head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device='cuda')
|
|
||||||
attention_ops.single_query_cached_kv_attention(
|
|
||||||
output,
|
|
||||||
query,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
scale,
|
|
||||||
block_tables,
|
|
||||||
context_lens,
|
|
||||||
block_size,
|
|
||||||
max_context_len,
|
|
||||||
None, # ALiBi slopes.
|
|
||||||
)
|
|
||||||
|
|
||||||
ref_output = torch.empty_like(query)
|
|
||||||
ref_single_query_cached_kv_attention(
|
|
||||||
ref_output,
|
|
||||||
query,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
block_tables,
|
|
||||||
context_lens,
|
|
||||||
)
|
|
||||||
# NOTE(woosuk): Due to the difference in the data types the two
|
|
||||||
# implementations use for attention softmax logits and accumulation,
|
|
||||||
# there is a small difference in the final outputs.
|
|
||||||
# We should use a relaxed tolerance for the test.
|
|
||||||
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def run_multi_query_kv_attention(
|
|
||||||
num_seqs: int,
|
num_seqs: int,
|
||||||
num_heads: int,
|
num_heads: Tuple[int, int],
|
||||||
head_size: int,
|
head_size: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
|
random.seed(seed)
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
|
||||||
|
# As the xformers library is already tested with its own tests, we can use
|
||||||
|
# a smaller MAX_SEQ_LEN here.
|
||||||
|
max_len = min(MAX_SEQ_LEN, 4096)
|
||||||
|
seq_lens = random.sample(range(1, max_len), num_seqs)
|
||||||
num_tokens = sum(seq_lens)
|
num_tokens = sum(seq_lens)
|
||||||
|
|
||||||
scale = float(1.0 / (head_size**0.5))
|
scale = float(1.0 / (head_size**0.5))
|
||||||
|
num_query_heads, num_kv_heads = num_heads
|
||||||
qkv = torch.empty(num_tokens,
|
qkv = torch.empty(num_tokens,
|
||||||
3,
|
num_query_heads + 2 * num_kv_heads,
|
||||||
num_heads,
|
|
||||||
head_size,
|
head_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device='cuda')
|
device="cuda")
|
||||||
qkv.uniform_(-1e-3, 1e-3)
|
qkv.uniform_(-scale, scale)
|
||||||
query, key, value = qkv.unbind(dim=1)
|
query, key, value = qkv.split(
|
||||||
|
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)
|
||||||
|
|
||||||
attn_op = xops.fmha.cutlass.FwOp()
|
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||||
|
if num_queries_per_kv > 1:
|
||||||
|
# Handle MQA and GQA
|
||||||
|
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
|
||||||
|
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
|
||||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
|
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
|
||||||
output = xops.memory_efficient_attention_forward(
|
output = xops.memory_efficient_attention_forward(
|
||||||
query.unsqueeze(0),
|
query.unsqueeze(0),
|
||||||
@@ -264,7 +278,6 @@ def run_multi_query_kv_attention(
|
|||||||
attn_bias=attn_bias,
|
attn_bias=attn_bias,
|
||||||
p=0.0,
|
p=0.0,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
op=attn_op,
|
|
||||||
)
|
)
|
||||||
output = output.squeeze(0)
|
output = output.squeeze(0)
|
||||||
|
|
||||||
@@ -276,40 +289,7 @@ def run_multi_query_kv_attention(
|
|||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
|
scale,
|
||||||
dtype,
|
dtype,
|
||||||
)
|
)
|
||||||
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
def test_single_query_cached_kv_attention() -> None:
|
|
||||||
torch.random.manual_seed(TEST_SEED)
|
|
||||||
torch.cuda.manual_seed(TEST_SEED)
|
|
||||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
|
||||||
for block_size in [8, 16, 32]:
|
|
||||||
for head_size in [64, 80, 96, 128]:
|
|
||||||
print(f'Testing single_query_cached_kv_attention with '
|
|
||||||
f'dtype={dtype}, block_size={block_size}, '
|
|
||||||
f'head_size={head_size}')
|
|
||||||
run_single_query_cached_kv_attention(
|
|
||||||
num_tokens=37,
|
|
||||||
num_heads=3,
|
|
||||||
head_size=head_size,
|
|
||||||
block_size=block_size,
|
|
||||||
num_blocks=1024,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_multi_query_kv_attention() -> None:
|
|
||||||
torch.random.manual_seed(TEST_SEED)
|
|
||||||
torch.cuda.manual_seed(TEST_SEED)
|
|
||||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
|
||||||
for head_size in [64, 80, 96, 128]:
|
|
||||||
print(f'Testing multi_query_kv_attention with dtype={dtype}, '
|
|
||||||
f'head_size={head_size}')
|
|
||||||
run_multi_query_kv_attention(
|
|
||||||
num_seqs=5,
|
|
||||||
num_heads=3,
|
|
||||||
head_size=head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,12 +1,32 @@
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import cache_ops
|
from vllm import cache_ops
|
||||||
|
|
||||||
|
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||||
|
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
|
||||||
|
NUM_LAYERS = [5] # Arbitrary values for testing
|
||||||
|
NUM_HEADS = [8] # Arbitrary values for testing
|
||||||
|
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||||
|
BLOCK_SIZES = [8, 16, 32]
|
||||||
|
NUM_BLOCKS = [1024] # Arbitrary values for testing
|
||||||
|
NUM_MAPPINGS = [32, 256] # Arbitrary values for testing
|
||||||
|
SEEDS = [0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
|
||||||
|
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
|
||||||
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
|
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||||
|
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def run_copy_blocks(
|
def test_copy_blocks(
|
||||||
|
kv_cache_factory,
|
||||||
num_mappings: int,
|
num_mappings: int,
|
||||||
num_layers: int,
|
num_layers: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
@@ -14,48 +34,43 @@ def run_copy_blocks(
|
|||||||
block_size: int,
|
block_size: int,
|
||||||
num_blocks: int,
|
num_blocks: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Generate random block mappings.
|
random.seed(seed)
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
# Generate random block mappings where each source block is mapped to two
|
||||||
|
# destination blocks.
|
||||||
|
assert 2 * num_mappings <= num_blocks
|
||||||
src_blocks = random.sample(range(num_blocks), num_mappings)
|
src_blocks = random.sample(range(num_blocks), num_mappings)
|
||||||
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
|
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
|
||||||
dst_blocks = random.sample(remainig_blocks, num_mappings)
|
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
|
||||||
block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)}
|
block_mapping = {}
|
||||||
|
for i in range(num_mappings):
|
||||||
|
src = src_blocks[i]
|
||||||
|
dst1 = dst_blocks[2 * i]
|
||||||
|
dst2 = dst_blocks[2 * i + 1]
|
||||||
|
block_mapping[src] = [dst1, dst2]
|
||||||
|
|
||||||
# Create the KV cache.
|
# Create the KV caches.
|
||||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
|
||||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
num_layers, num_heads,
|
||||||
key_caches = []
|
head_size, dtype, seed)
|
||||||
for _ in range(num_layers):
|
|
||||||
key_cache = torch.randn(size=key_cache_shape,
|
|
||||||
dtype=dtype,
|
|
||||||
device='cuda')
|
|
||||||
key_caches.append(key_cache)
|
|
||||||
cloned_key_caches = []
|
|
||||||
for key_cache in key_caches:
|
|
||||||
cloned_key_caches.append(key_cache.clone())
|
|
||||||
|
|
||||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
# Clone the KV caches.
|
||||||
value_caches = []
|
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
|
||||||
for _ in range(num_layers):
|
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
|
||||||
value_cache = torch.randn(size=value_cache_shape,
|
|
||||||
dtype=dtype,
|
|
||||||
device='cuda')
|
|
||||||
value_caches.append(value_cache)
|
|
||||||
cloned_value_caches = []
|
|
||||||
for value_cache in value_caches:
|
|
||||||
cloned_value_caches.append(value_cache.clone())
|
|
||||||
|
|
||||||
# Call the copy blocks kernel.
|
# Call the copy blocks kernel.
|
||||||
cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
|
cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
|
||||||
|
|
||||||
# Reference implementation.
|
# Run the reference implementation.
|
||||||
for src, dsts in block_mapping.items():
|
for src, dsts in block_mapping.items():
|
||||||
for dst in dsts:
|
for dst in dsts:
|
||||||
for key_cache, cloned_key_cache in zip(key_caches,
|
for cloned_key_cache in cloned_key_caches:
|
||||||
cloned_key_caches):
|
|
||||||
cloned_key_cache[dst] = cloned_key_cache[src]
|
cloned_key_cache[dst] = cloned_key_cache[src]
|
||||||
for value_cache, cloned_value_cache in zip(value_caches,
|
for cloned_value_cache in cloned_value_caches:
|
||||||
cloned_value_caches):
|
|
||||||
cloned_value_cache[dst] = cloned_value_cache[src]
|
cloned_value_cache[dst] = cloned_value_cache[src]
|
||||||
|
|
||||||
# Compare the results.
|
# Compare the results.
|
||||||
@@ -66,15 +81,29 @@ def run_copy_blocks(
|
|||||||
assert torch.allclose(value_cache, cloned_value_cache)
|
assert torch.allclose(value_cache, cloned_value_cache)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||||
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
|
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||||
|
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def run_reshape_and_cache(
|
def test_reshape_and_cache(
|
||||||
|
kv_cache_factory,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
block_size: int,
|
block_size: int,
|
||||||
num_blocks: int,
|
num_blocks: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
random.seed(seed)
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
# Create a random slot mapping.
|
||||||
num_slots = block_size * num_blocks
|
num_slots = block_size * num_blocks
|
||||||
slot_mapping = random.sample(range(num_slots), num_tokens)
|
slot_mapping = random.sample(range(num_slots), num_tokens)
|
||||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
|
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
|
||||||
@@ -87,110 +116,31 @@ def run_reshape_and_cache(
|
|||||||
device='cuda')
|
device='cuda')
|
||||||
_, key, value = qkv.unbind(dim=1)
|
_, key, value = qkv.unbind(dim=1)
|
||||||
|
|
||||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
# Create the KV caches.
|
||||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
|
||||||
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
|
num_heads, head_size, dtype,
|
||||||
cloned_key_cache = key_cache.clone()
|
seed)
|
||||||
|
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||||
|
|
||||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
# Clone the KV caches.
|
||||||
value_cache = torch.randn(size=value_cache_shape,
|
cloned_key_cache = key_cache.clone()
|
||||||
dtype=dtype,
|
|
||||||
device='cuda')
|
|
||||||
cloned_value_cache = value_cache.clone()
|
cloned_value_cache = value_cache.clone()
|
||||||
|
|
||||||
|
# Call the reshape_and_cache kernel.
|
||||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
|
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
|
||||||
slot_mapping)
|
slot_mapping)
|
||||||
|
|
||||||
|
# Run the reference implementation.
|
||||||
|
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
|
||||||
|
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
|
||||||
|
block_indicies = block_indicies.cpu().tolist()
|
||||||
|
block_offsets = slot_mapping % block_size
|
||||||
|
block_offsets = block_offsets.cpu().tolist()
|
||||||
for i in range(num_tokens):
|
for i in range(num_tokens):
|
||||||
reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
|
block_idx = block_indicies[i]
|
||||||
block_idx = torch.div(slot_mapping[i],
|
block_offset = block_offsets[i]
|
||||||
block_size,
|
|
||||||
rounding_mode='floor')
|
|
||||||
block_offset = slot_mapping[i] % block_size
|
|
||||||
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
|
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
|
||||||
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
|
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
|
||||||
|
|
||||||
assert torch.allclose(key_cache, cloned_key_cache)
|
assert torch.allclose(key_cache, cloned_key_cache)
|
||||||
assert torch.allclose(value_cache, cloned_value_cache)
|
assert torch.allclose(value_cache, cloned_value_cache)
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def run_gather_cached_kv(
|
|
||||||
num_tokens: int,
|
|
||||||
num_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
block_size: int,
|
|
||||||
num_blocks: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
) -> None:
|
|
||||||
num_slots = block_size * num_blocks
|
|
||||||
slot_mapping = random.sample(range(num_slots), num_tokens)
|
|
||||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
|
|
||||||
|
|
||||||
qkv = torch.randn(num_tokens,
|
|
||||||
3,
|
|
||||||
num_heads,
|
|
||||||
head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device='cuda')
|
|
||||||
_, key, value = qkv.unbind(dim=1)
|
|
||||||
|
|
||||||
qkv_clone = qkv.clone()
|
|
||||||
_, cloned_key, cloned_value = qkv_clone.unbind(dim=1)
|
|
||||||
|
|
||||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
|
||||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
|
||||||
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
|
|
||||||
|
|
||||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
|
||||||
value_cache = torch.randn(size=value_cache_shape,
|
|
||||||
dtype=dtype,
|
|
||||||
device='cuda')
|
|
||||||
|
|
||||||
cache_ops.gather_cached_kv(key, value, key_cache, value_cache,
|
|
||||||
slot_mapping)
|
|
||||||
|
|
||||||
# Reference implementation.
|
|
||||||
for i in range(num_tokens):
|
|
||||||
reshaped_key = cloned_key.reshape(num_tokens, num_heads,
|
|
||||||
head_size // x, x)
|
|
||||||
block_idx = torch.div(slot_mapping[i],
|
|
||||||
block_size,
|
|
||||||
rounding_mode='floor')
|
|
||||||
block_offset = slot_mapping[i] % block_size
|
|
||||||
reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
|
|
||||||
cloned_value[i] = value_cache[block_idx, :, :, block_offset]
|
|
||||||
|
|
||||||
assert torch.allclose(key, cloned_key)
|
|
||||||
assert torch.allclose(value, cloned_value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_copy_blocks() -> None:
|
|
||||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
|
||||||
run_copy_blocks(num_mappings=23,
|
|
||||||
num_layers=7,
|
|
||||||
num_heads=17,
|
|
||||||
head_size=16,
|
|
||||||
block_size=8,
|
|
||||||
num_blocks=1024,
|
|
||||||
dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def test_reshape_and_cache() -> None:
|
|
||||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
|
||||||
run_reshape_and_cache(num_tokens=3,
|
|
||||||
num_heads=2,
|
|
||||||
head_size=16,
|
|
||||||
block_size=8,
|
|
||||||
num_blocks=2,
|
|
||||||
dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def test_gather_cached_kv() -> None:
|
|
||||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
|
||||||
run_gather_cached_kv(num_tokens=3,
|
|
||||||
num_heads=2,
|
|
||||||
head_size=16,
|
|
||||||
block_size=8,
|
|
||||||
num_blocks=2,
|
|
||||||
dtype=dtype)
|
|
||||||
|
|||||||
@@ -1,35 +1,50 @@
|
|||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm import layernorm_ops
|
from vllm import layernorm_ops
|
||||||
|
|
||||||
|
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||||
|
HIDDEN_SIZES = [67, 768, 2048, 5120, 8192] # Arbitrary values for testing
|
||||||
|
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
|
||||||
|
SEEDS = [0]
|
||||||
|
|
||||||
|
|
||||||
class RefRMSNorm(nn.Module):
|
class RefRMSNorm(nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
weight = torch.empty(hidden_size)
|
weight = torch.empty(hidden_size)
|
||||||
weight.uniform_(-1e-3, 1e-3)
|
weight.normal_(mean=1.0, std=0.1)
|
||||||
self.weight = nn.Parameter(weight)
|
self.weight = nn.Parameter(weight)
|
||||||
self.variance_epsilon = eps
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1,
|
input_dtype = hidden_states.dtype
|
||||||
keepdim=True)
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
hidden_states = hidden_states * torch.rsqrt(variance +
|
hidden_states = hidden_states * torch.rsqrt(variance +
|
||||||
self.variance_epsilon)
|
self.variance_epsilon)
|
||||||
if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
|
return self.weight * hidden_states.to(input_dtype)
|
||||||
hidden_states = hidden_states.to(self.weight.dtype)
|
|
||||||
return self.weight * hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||||
|
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def run_rms_norm(
|
def test_rms_norm(
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype, device='cuda')
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
scale = float(hidden_size**-0.5)
|
||||||
|
x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
||||||
|
x.uniform_(-scale, scale)
|
||||||
ref = RefRMSNorm(hidden_size).to(dtype).cuda()
|
ref = RefRMSNorm(hidden_size).to(dtype).cuda()
|
||||||
|
|
||||||
out = torch.empty_like(x)
|
out = torch.empty_like(x)
|
||||||
@@ -40,17 +55,4 @@ def run_rms_norm(
|
|||||||
ref.variance_epsilon,
|
ref.variance_epsilon,
|
||||||
)
|
)
|
||||||
ref_out = ref(x)
|
ref_out = ref(x)
|
||||||
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)
|
assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
def test_rms_norm() -> None:
|
|
||||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
|
||||||
for num_tokens in [7, 128, 2048]:
|
|
||||||
for hidden_size in [13, 64, 1024, 5120]:
|
|
||||||
print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
|
|
||||||
f'{num_tokens}, hidden_size={hidden_size}')
|
|
||||||
run_rms_norm(
|
|
||||||
num_tokens=num_tokens,
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,47 +1,70 @@
|
|||||||
from typing import Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from vllm import pos_encoding_ops
|
from vllm import pos_encoding_ops
|
||||||
|
|
||||||
|
IS_NEOX_STYLE = [True, False]
|
||||||
|
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||||
|
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||||
|
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
|
||||||
|
NUM_HEADS = [7, 12, 40, 52] # Arbitrary values for testing
|
||||||
|
NUM_TOKENS = [11, 83, 2048] # Arbitrary values for testing
|
||||||
|
SEEDS = [0]
|
||||||
|
|
||||||
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
|
||||||
|
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
||||||
x1 = x[..., :x.shape[-1] // 2]
|
x1 = x[..., :x.shape[-1] // 2]
|
||||||
x2 = x[..., x.shape[-1] // 2:]
|
x2 = x[..., x.shape[-1] // 2:]
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(
|
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x1 = x[..., ::2]
|
||||||
|
x2 = x[..., 1::2]
|
||||||
|
x = torch.stack((-x2, x1), dim=-1)
|
||||||
|
return x.flatten(-2)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rope(
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
cos: torch.Tensor,
|
cos: torch.Tensor,
|
||||||
sin: torch.Tensor,
|
sin: torch.Tensor,
|
||||||
|
is_neox_style: bool,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
rotate_fn = rotate_neox if is_neox_style else rotate_gptj
|
||||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
q_embed = (q * cos) + (rotate_fn(q) * sin)
|
||||||
|
k_embed = (k * cos) + (rotate_fn(k) * sin)
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
class RefRotaryEmbeddingNeox(nn.Module):
|
class RefRotaryEmbedding(nn.Module):
|
||||||
"""Reference implementation of the GPT-NeoX style rotary embedding."""
|
"""Reference implementation of rotary embedding."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dim: int,
|
dim: int,
|
||||||
max_position_embeddings: int = 2048,
|
is_neox_style: bool,
|
||||||
|
max_position_embeddings: int = 8192,
|
||||||
base: int = 10000,
|
base: int = 10000,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rotary_dim = dim
|
self.rotary_dim = dim
|
||||||
|
self.is_neox_style = is_neox_style
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
# Create cos and sin embeddings.
|
# Create cos and sin embeddings.
|
||||||
inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
|
inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
|
||||||
t = torch.arange(max_position_embeddings).float()
|
t = torch.arange(max_position_embeddings).float()
|
||||||
freqs = torch.einsum("i,j->ij", t, inv_freq.float())
|
freqs = torch.einsum("i,j->ij", t, inv_freq.float())
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
if is_neox_style:
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
else:
|
||||||
|
emb = torch.repeat_interleave(freqs, 2, -1)
|
||||||
cos = emb.cos().to(dtype=inv_freq.dtype)
|
cos = emb.cos().to(dtype=inv_freq.dtype)
|
||||||
sin = emb.sin().to(dtype=inv_freq.dtype)
|
sin = emb.sin().to(dtype=inv_freq.dtype)
|
||||||
self.register_buffer("cos_cached", cos, persistent=False)
|
self.register_buffer("cos_cached", cos, persistent=False)
|
||||||
@@ -53,7 +76,6 @@ class RefRotaryEmbeddingNeox(nn.Module):
|
|||||||
query: torch.Tensor, # [num_tokens, num_heads, head_size]
|
query: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||||
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
query_rot = query[..., :self.rotary_dim]
|
query_rot = query[..., :self.rotary_dim]
|
||||||
query_pass = query[..., self.rotary_dim:]
|
query_pass = query[..., self.rotary_dim:]
|
||||||
key_rot = key[..., :self.rotary_dim]
|
key_rot = key[..., :self.rotary_dim]
|
||||||
@@ -63,7 +85,9 @@ class RefRotaryEmbeddingNeox(nn.Module):
|
|||||||
key_rot = key_rot.transpose(0, 1)
|
key_rot = key_rot.transpose(0, 1)
|
||||||
cos = F.embedding(positions, self.cos_cached)
|
cos = F.embedding(positions, self.cos_cached)
|
||||||
sin = F.embedding(positions, self.sin_cached)
|
sin = F.embedding(positions, self.sin_cached)
|
||||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
|
|
||||||
|
query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
|
||||||
|
self.is_neox_style)
|
||||||
query_rot = query_rot.transpose(0, 1).contiguous()
|
query_rot = query_rot.transpose(0, 1).contiguous()
|
||||||
key_rot = key_rot.transpose(0, 1).contiguous()
|
key_rot = key_rot.transpose(0, 1).contiguous()
|
||||||
|
|
||||||
@@ -74,30 +98,45 @@ class RefRotaryEmbeddingNeox(nn.Module):
|
|||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||||
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||||
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
|
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def run_rotary_embedding_neox(
|
def test_rotary_embedding(
|
||||||
|
is_neox_style: bool,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
max_position: int,
|
rotary_dim: Optional[int],
|
||||||
rotary_dim: int,
|
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
|
max_position: int = 8192,
|
||||||
base: int = 10000,
|
base: int = 10000,
|
||||||
) -> None:
|
) -> None:
|
||||||
positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
|
if rotary_dim is None:
|
||||||
|
rotary_dim = head_size
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
|
||||||
query = torch.randn(num_tokens,
|
query = torch.randn(num_tokens,
|
||||||
num_heads * head_size,
|
num_heads * head_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device='cuda')
|
device="cuda")
|
||||||
key = torch.randn(num_tokens,
|
key = torch.randn(num_tokens,
|
||||||
num_heads * head_size,
|
num_heads * head_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device='cuda')
|
device="cuda")
|
||||||
|
|
||||||
# Create the rotary embedding.
|
# Create the rotary embedding.
|
||||||
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
|
inv_freq = 1.0 / (base**(
|
||||||
|
torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
|
||||||
t = torch.arange(max_position).float()
|
t = torch.arange(max_position).float()
|
||||||
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
|
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||||
cos = freqs.cos()
|
cos = freqs.cos()
|
||||||
sin = freqs.sin()
|
sin = freqs.sin()
|
||||||
cos_sin_cache = torch.cat((cos, sin), dim=-1)
|
cos_sin_cache = torch.cat((cos, sin), dim=-1)
|
||||||
@@ -106,20 +145,22 @@ def run_rotary_embedding_neox(
|
|||||||
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
|
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
|
||||||
out_query = query.clone()
|
out_query = query.clone()
|
||||||
out_key = key.clone()
|
out_key = key.clone()
|
||||||
pos_encoding_ops.rotary_embedding_neox(
|
pos_encoding_ops.rotary_embedding(
|
||||||
positions,
|
positions,
|
||||||
out_query,
|
out_query,
|
||||||
out_key,
|
out_key,
|
||||||
head_size,
|
head_size,
|
||||||
cos_sin_cache,
|
cos_sin_cache,
|
||||||
|
is_neox_style,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run the reference implementation.
|
# Run the reference implementation.
|
||||||
ref_rotary_embedding = RefRotaryEmbeddingNeox(
|
ref_rotary_embedding = RefRotaryEmbedding(
|
||||||
dim=rotary_dim,
|
dim=rotary_dim,
|
||||||
|
is_neox_style=is_neox_style,
|
||||||
max_position_embeddings=max_position,
|
max_position_embeddings=max_position,
|
||||||
base=base,
|
base=base,
|
||||||
).to(dtype=dtype, device='cuda')
|
).to(dtype=dtype, device="cuda")
|
||||||
ref_query, ref_key = ref_rotary_embedding(
|
ref_query, ref_key = ref_rotary_embedding(
|
||||||
positions,
|
positions,
|
||||||
query.view(num_tokens, num_heads, head_size),
|
query.view(num_tokens, num_heads, head_size),
|
||||||
@@ -129,19 +170,5 @@ def run_rotary_embedding_neox(
|
|||||||
ref_key = ref_key.view(num_tokens, num_heads * head_size)
|
ref_key = ref_key.view(num_tokens, num_heads * head_size)
|
||||||
|
|
||||||
# Compare the results.
|
# Compare the results.
|
||||||
assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
|
assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
|
||||||
assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
|
assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
def test_rotary_embedding_neox() -> None:
|
|
||||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
|
||||||
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
|
|
||||||
print(f'Running tests for head_size={head_size} and dtype={dtype}')
|
|
||||||
run_rotary_embedding_neox(
|
|
||||||
num_tokens=2145,
|
|
||||||
num_heads=5,
|
|
||||||
head_size=head_size,
|
|
||||||
max_position=8192,
|
|
||||||
rotary_dim=head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|||||||
45
tests/models/test_models.py
Normal file
45
tests/models/test_models.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
"""Compare the outputs of HF and vLLM when using greedy sampling.
|
||||||
|
|
||||||
|
Run `pytest tests/models/test_models.py --forked`.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
MODELS = [
|
||||||
|
"facebook/opt-125m",
|
||||||
|
"gpt2",
|
||||||
|
"bigcode/tiny_starcoder_py",
|
||||||
|
"EleutherAI/gpt-j-6b",
|
||||||
|
"EleutherAI/pythia-70m",
|
||||||
|
"bigscience/bloom-560m",
|
||||||
|
"mosaicml/mpt-7b",
|
||||||
|
"tiiuae/falcon-7b",
|
||||||
|
"meta-llama/Llama-2-7b-hf",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [128])
|
||||||
|
def test_models(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
example_prompts,
|
||||||
|
model: str,
|
||||||
|
dtype: str,
|
||||||
|
max_tokens: int,
|
||||||
|
) -> None:
|
||||||
|
hf_model = hf_runner(model, dtype=dtype)
|
||||||
|
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
del hf_model
|
||||||
|
|
||||||
|
vllm_model = vllm_runner(model, dtype=dtype)
|
||||||
|
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
del vllm_model
|
||||||
|
|
||||||
|
for i in range(len(example_prompts)):
|
||||||
|
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||||
|
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||||
|
assert hf_output_str == vllm_output_str, (
|
||||||
|
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||||
|
assert hf_output_ids == vllm_output_ids, (
|
||||||
|
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||||
46
tests/samplers/test_beam_search.py
Normal file
46
tests/samplers/test_beam_search.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Compare the outputs of HF and vLLM when using beam search.
|
||||||
|
|
||||||
|
Run `pytest tests/samplers/test_beam_search.py --forked`.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# FIXME(zhuohan): The test can not pass if we:
|
||||||
|
# 1. Increase max_tokens to 256.
|
||||||
|
# 2. Increase beam_width to 8.
|
||||||
|
# 3. Use the model "huggyllama/llama-7b".
|
||||||
|
MAX_TOKENS = [128]
|
||||||
|
BEAM_WIDTHS = [4]
|
||||||
|
MODELS = ["facebook/opt-125m"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
|
||||||
|
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
|
||||||
|
def test_beam_search_single_input(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
example_prompts,
|
||||||
|
model: str,
|
||||||
|
dtype: str,
|
||||||
|
max_tokens: int,
|
||||||
|
beam_width: int,
|
||||||
|
) -> None:
|
||||||
|
hf_model = hf_runner(model, dtype=dtype)
|
||||||
|
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
|
||||||
|
max_tokens)
|
||||||
|
del hf_model
|
||||||
|
|
||||||
|
vllm_model = vllm_runner(model, dtype=dtype)
|
||||||
|
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
|
||||||
|
max_tokens)
|
||||||
|
del vllm_model
|
||||||
|
|
||||||
|
for i in range(len(example_prompts)):
|
||||||
|
hf_output_ids, _ = hf_outputs[i]
|
||||||
|
vllm_output_ids, _ = vllm_outputs[i]
|
||||||
|
assert len(hf_output_ids) == len(vllm_output_ids)
|
||||||
|
for j in range(len(hf_output_ids)):
|
||||||
|
assert hf_output_ids[j] == vllm_output_ids[j], (
|
||||||
|
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
|
||||||
|
f"vLLM: {vllm_output_ids}")
|
||||||
184
tests/samplers/test_sampler.py
Normal file
184
tests/samplers/test_sampler.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
import pytest
|
||||||
|
import random
|
||||||
|
from typing import Tuple
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.utils import set_random_seed
|
||||||
|
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||||
|
from vllm.worker.worker import Worker
|
||||||
|
|
||||||
|
|
||||||
|
class MockLogitsSampler(Sampler):
|
||||||
|
|
||||||
|
def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
|
||||||
|
super().__init__(vocab_size=vocab_size)
|
||||||
|
self.fake_logits = fake_logits
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
|
||||||
|
lambda x, y: x):
|
||||||
|
with patch("vllm.model_executor.layers.sampler._get_logits",
|
||||||
|
lambda *args, **kwargs: self.fake_logits):
|
||||||
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_test(
|
||||||
|
batch_size: int
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
|
||||||
|
vocab_size = 32000
|
||||||
|
input_tensor = torch.rand((batch_size, 1024),
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.float16)
|
||||||
|
fake_logits = torch.full((batch_size, vocab_size),
|
||||||
|
1e-2,
|
||||||
|
device=input_tensor.device,
|
||||||
|
dtype=input_tensor.dtype)
|
||||||
|
sampler = MockLogitsSampler(32000, fake_logits)
|
||||||
|
worker = Worker(None, None, None)
|
||||||
|
worker.block_size = 16
|
||||||
|
return input_tensor, fake_logits, sampler, worker
|
||||||
|
|
||||||
|
|
||||||
|
RANDOM_SEEDS = list(range(128))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||||
|
def test_sampler_all_greedy(seed: int):
|
||||||
|
set_random_seed(seed)
|
||||||
|
batch_size = random.randint(1, 256)
|
||||||
|
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
|
||||||
|
|
||||||
|
seq_group_metadata_list = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
seq_group_metadata_list.append(
|
||||||
|
SequenceGroupMetadata(
|
||||||
|
request_id=f"test_{i}",
|
||||||
|
is_prompt=True,
|
||||||
|
seq_data={0: SequenceData([1, 2, 3])},
|
||||||
|
sampling_params=SamplingParams(temperature=0, ),
|
||||||
|
block_tables={0: [1]},
|
||||||
|
))
|
||||||
|
|
||||||
|
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
|
||||||
|
sampler_output = sampler(embedding=None,
|
||||||
|
hidden_states=input_tensor,
|
||||||
|
input_metadata=input_metadata)
|
||||||
|
expected = torch.argmax(fake_logits, dim=-1)
|
||||||
|
for i, sequence_output in enumerate(sampler_output):
|
||||||
|
for nth_output in sequence_output:
|
||||||
|
assert nth_output.output_token == expected[i].item()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||||
|
def test_sampler_all_random(seed: int):
|
||||||
|
set_random_seed(seed)
|
||||||
|
batch_size = random.randint(1, 256)
|
||||||
|
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
fake_logits[i, i] = 1e2
|
||||||
|
|
||||||
|
seq_group_metadata_list = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
seq_group_metadata_list.append(
|
||||||
|
SequenceGroupMetadata(
|
||||||
|
request_id=f"test_{i}",
|
||||||
|
is_prompt=True,
|
||||||
|
seq_data={0: SequenceData([1, 2, 3])},
|
||||||
|
sampling_params=SamplingParams(
|
||||||
|
temperature=1.0,
|
||||||
|
n=random.randint(1, 10),
|
||||||
|
),
|
||||||
|
block_tables={0: [1]},
|
||||||
|
))
|
||||||
|
|
||||||
|
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
|
||||||
|
sampler_output = sampler(embedding=None,
|
||||||
|
hidden_states=input_tensor,
|
||||||
|
input_metadata=input_metadata)
|
||||||
|
for i, sequence_output in enumerate(sampler_output):
|
||||||
|
for nth_output in sequence_output:
|
||||||
|
assert nth_output.output_token == i
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||||
|
def test_sampler_all_beam(seed: int):
|
||||||
|
set_random_seed(seed)
|
||||||
|
batch_size = random.randint(1, 256)
|
||||||
|
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
|
||||||
|
|
||||||
|
seq_group_metadata_list = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
seq_group_metadata_list.append(
|
||||||
|
SequenceGroupMetadata(
|
||||||
|
request_id=f"test_{i}",
|
||||||
|
is_prompt=True,
|
||||||
|
seq_data={0: SequenceData([1, 2, 3])},
|
||||||
|
sampling_params=SamplingParams(
|
||||||
|
temperature=0,
|
||||||
|
best_of=2,
|
||||||
|
use_beam_search=True,
|
||||||
|
),
|
||||||
|
block_tables={0: [1]},
|
||||||
|
))
|
||||||
|
|
||||||
|
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
|
||||||
|
sampler(embedding=None,
|
||||||
|
hidden_states=input_tensor,
|
||||||
|
input_metadata=input_metadata)
|
||||||
|
# no assertion here as I am not sure how to determine whether
|
||||||
|
# the outputs are expected - in other words, this just tests
|
||||||
|
# whether there are no exceptions in the sampler
|
||||||
|
# when handling an all-beam search case.
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||||
|
def test_sampler_mixed(seed: int):
|
||||||
|
set_random_seed(seed)
|
||||||
|
batch_size = random.randint(1, 256)
|
||||||
|
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
|
||||||
|
|
||||||
|
seq_group_metadata_list = []
|
||||||
|
expected_tokens = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
n = 1
|
||||||
|
sampling_type = random.randint(0, 2)
|
||||||
|
if sampling_type == 0:
|
||||||
|
sampling_params = SamplingParams(temperature=0)
|
||||||
|
elif sampling_type == 1:
|
||||||
|
n = random.randint(1, 10)
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=random.random() + 0.1,
|
||||||
|
top_p=min(random.random() + 0.1, 1),
|
||||||
|
top_k=random.randint(0, 10) or -1,
|
||||||
|
n=n,
|
||||||
|
presence_penalty=random.randint(0, 1),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sampling_params = SamplingParams(temperature=0,
|
||||||
|
use_beam_search=True,
|
||||||
|
best_of=2)
|
||||||
|
for idx in range(n):
|
||||||
|
fake_logits[i, i + idx] = 1e2
|
||||||
|
expected_tokens.append(i + idx)
|
||||||
|
seq_group_metadata_list.append(
|
||||||
|
SequenceGroupMetadata(
|
||||||
|
request_id=f"test_{i}",
|
||||||
|
is_prompt=True,
|
||||||
|
seq_data={0: SequenceData([1, 2, 3])},
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
block_tables={0: [1]},
|
||||||
|
))
|
||||||
|
|
||||||
|
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
|
||||||
|
sampler_output = sampler(embedding=None,
|
||||||
|
hidden_states=input_tensor,
|
||||||
|
input_metadata=input_metadata)
|
||||||
|
for i, sequence_output in enumerate(sampler_output):
|
||||||
|
if seq_group_metadata_list[i].sampling_params.use_beam_search:
|
||||||
|
continue
|
||||||
|
for nth_output in sequence_output:
|
||||||
|
assert nth_output.output_token in expected_tokens
|
||||||
@@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
|
|||||||
from vllm.outputs import CompletionOutput, RequestOutput
|
from vllm.outputs import CompletionOutput, RequestOutput
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
|
|
||||||
__version__ = "0.1.2"
|
__version__ = "0.2.0"
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"LLM",
|
"LLM",
|
||||||
|
|||||||
179
vllm/config.py
179
vllm/config.py
@@ -20,15 +20,31 @@ class ModelConfig:
|
|||||||
tokenizer: Name or path of the huggingface tokenizer to use.
|
tokenizer: Name or path of the huggingface tokenizer to use.
|
||||||
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
|
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
|
||||||
available, and "slow" will always use the slow tokenizer.
|
available, and "slow" will always use the slow tokenizer.
|
||||||
|
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
|
||||||
|
downloading the model and tokenizer.
|
||||||
download_dir: Directory to download and load the weights, default to the
|
download_dir: Directory to download and load the weights, default to the
|
||||||
default cache directory of huggingface.
|
default cache directory of huggingface.
|
||||||
use_np_weights: Save a numpy copy of model weights for faster loading.
|
load_format: The format of the model weights to load:
|
||||||
This can increase the disk usage by up to 2x.
|
"auto" will try to load the weights in the safetensors format and
|
||||||
use_dummy_weights: Use dummy values for model weights (for profiling).
|
fall back to the pytorch bin format if safetensors format is
|
||||||
|
not available.
|
||||||
|
"pt" will load the weights in the pytorch bin format.
|
||||||
|
"safetensors" will load the weights in the safetensors format.
|
||||||
|
"npcache" will load the weights in pytorch format and store
|
||||||
|
a numpy cache to speed up the loading.
|
||||||
|
"dummy" will initialize the weights with random values, which is
|
||||||
|
mainly for profiling.
|
||||||
dtype: Data type for model weights and activations. The "auto" option
|
dtype: Data type for model weights and activations. The "auto" option
|
||||||
will use FP16 precision for FP32 and FP16 models, and BF16 precision
|
will use FP16 precision for FP32 and FP16 models, and BF16 precision
|
||||||
for BF16 models.
|
for BF16 models.
|
||||||
seed: Random seed for reproducibility.
|
seed: Random seed for reproducibility.
|
||||||
|
revision: The specific model version to use. It can be a branch name,
|
||||||
|
a tag name, or a commit id. If unspecified, will use the default
|
||||||
|
version.
|
||||||
|
max_model_len: Maximum length of a sequence (including prompt and
|
||||||
|
output). If None, will be derived from the model.
|
||||||
|
quantization: Quantization method that was used to quantize the model
|
||||||
|
weights. If None, we assume the model weights are not quantized.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -36,23 +52,42 @@ class ModelConfig:
|
|||||||
model: str,
|
model: str,
|
||||||
tokenizer: str,
|
tokenizer: str,
|
||||||
tokenizer_mode: str,
|
tokenizer_mode: str,
|
||||||
|
trust_remote_code: bool,
|
||||||
download_dir: Optional[str],
|
download_dir: Optional[str],
|
||||||
use_np_weights: bool,
|
load_format: str,
|
||||||
use_dummy_weights: bool,
|
|
||||||
dtype: str,
|
dtype: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
max_model_len: Optional[int] = None,
|
||||||
|
quantization: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.tokenizer_mode = tokenizer_mode
|
self.tokenizer_mode = tokenizer_mode
|
||||||
|
self.trust_remote_code = trust_remote_code
|
||||||
self.download_dir = download_dir
|
self.download_dir = download_dir
|
||||||
self.use_np_weights = use_np_weights
|
self.load_format = load_format
|
||||||
self.use_dummy_weights = use_dummy_weights
|
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
self.revision = revision
|
||||||
|
self.quantization = quantization
|
||||||
|
|
||||||
self.hf_config = get_config(model)
|
self.hf_config = get_config(model, trust_remote_code, revision)
|
||||||
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
||||||
|
self.max_model_len = _get_and_verify_max_len(self.hf_config,
|
||||||
|
max_model_len)
|
||||||
|
self._verify_load_format()
|
||||||
self._verify_tokenizer_mode()
|
self._verify_tokenizer_mode()
|
||||||
|
self._verify_quantization()
|
||||||
|
|
||||||
|
def _verify_load_format(self) -> None:
|
||||||
|
load_format = self.load_format.lower()
|
||||||
|
if load_format not in [
|
||||||
|
"auto", "pt", "safetensors", "npcache", "dummy"
|
||||||
|
]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown load format: {self.load_format}. Must be one of "
|
||||||
|
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
|
||||||
|
self.load_format = load_format
|
||||||
|
|
||||||
def _verify_tokenizer_mode(self) -> None:
|
def _verify_tokenizer_mode(self) -> None:
|
||||||
tokenizer_mode = self.tokenizer_mode.lower()
|
tokenizer_mode = self.tokenizer_mode.lower()
|
||||||
@@ -62,6 +97,17 @@ class ModelConfig:
|
|||||||
"either 'auto' or 'slow'.")
|
"either 'auto' or 'slow'.")
|
||||||
self.tokenizer_mode = tokenizer_mode
|
self.tokenizer_mode = tokenizer_mode
|
||||||
|
|
||||||
|
def _verify_quantization(self) -> None:
|
||||||
|
supported_quantization = ["awq"]
|
||||||
|
if self.quantization is None:
|
||||||
|
return
|
||||||
|
quantization = self.quantization.lower()
|
||||||
|
if quantization not in supported_quantization:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown quantization: {self.quantization}. Must be one of "
|
||||||
|
f"{supported_quantization}.")
|
||||||
|
self.quantization = quantization
|
||||||
|
|
||||||
def verify_with_parallel_config(
|
def verify_with_parallel_config(
|
||||||
self,
|
self,
|
||||||
parallel_config: "ParallelConfig",
|
parallel_config: "ParallelConfig",
|
||||||
@@ -89,7 +135,32 @@ class ModelConfig:
|
|||||||
# FIXME(woosuk): This may not be true for all models.
|
# FIXME(woosuk): This may not be true for all models.
|
||||||
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
|
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
|
||||||
|
|
||||||
def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
|
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
|
||||||
|
"""Returns the number of KV heads per GPU worker."""
|
||||||
|
# For GPTBigCode & Falcon:
|
||||||
|
# Note: for falcon, when new_decoder_architecture is True, the
|
||||||
|
# multi_query flag is ignored and we use n_head_kv for the number of
|
||||||
|
# KV heads.
|
||||||
|
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
|
||||||
|
new_decoder_arch_falcon = (
|
||||||
|
self.hf_config.model_type in falcon_model_types
|
||||||
|
and getattr(self.hf_config, "new_decoder_architecture", False))
|
||||||
|
if not new_decoder_arch_falcon and getattr(self.hf_config,
|
||||||
|
"multi_query", False):
|
||||||
|
# Multi-query attention, only one KV head.
|
||||||
|
# Currently, tensor parallelism is not supported in this case.
|
||||||
|
return 1
|
||||||
|
# For Falcon:
|
||||||
|
if getattr(self.hf_config, "n_head_kv", None) is not None:
|
||||||
|
return (self.hf_config.n_head_kv //
|
||||||
|
parallel_config.tensor_parallel_size)
|
||||||
|
if getattr(self.hf_config, "num_kv_heads", None) is not None:
|
||||||
|
return (self.hf_config.num_kv_heads //
|
||||||
|
parallel_config.tensor_parallel_size)
|
||||||
|
# For LLaMA-2:
|
||||||
|
if getattr(self.hf_config, "num_key_value_heads", None) is not None:
|
||||||
|
return (self.hf_config.num_key_value_heads //
|
||||||
|
parallel_config.tensor_parallel_size)
|
||||||
total_num_attention_heads = self.hf_config.num_attention_heads
|
total_num_attention_heads = self.hf_config.num_attention_heads
|
||||||
return total_num_attention_heads // parallel_config.tensor_parallel_size
|
return total_num_attention_heads // parallel_config.tensor_parallel_size
|
||||||
|
|
||||||
@@ -113,10 +184,12 @@ class CacheConfig:
|
|||||||
block_size: int,
|
block_size: int,
|
||||||
gpu_memory_utilization: float,
|
gpu_memory_utilization: float,
|
||||||
swap_space: int,
|
swap_space: int,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.block_size = block_size
|
self.block_size = block_size
|
||||||
self.gpu_memory_utilization = gpu_memory_utilization
|
self.gpu_memory_utilization = gpu_memory_utilization
|
||||||
self.swap_space_bytes = swap_space * _GB
|
self.swap_space_bytes = swap_space * _GB
|
||||||
|
self.sliding_window = sliding_window
|
||||||
self._verify_args()
|
self._verify_args()
|
||||||
|
|
||||||
# Will be set after profiling.
|
# Will be set after profiling.
|
||||||
@@ -188,15 +261,40 @@ class SchedulerConfig:
|
|||||||
a single iteration.
|
a single iteration.
|
||||||
max_num_seqs: Maximum number of sequences to be processed in a single
|
max_num_seqs: Maximum number of sequences to be processed in a single
|
||||||
iteration.
|
iteration.
|
||||||
max_seq_len: Maximum length of a sequence (including prompt
|
max_model_len: Maximum length of a sequence (including prompt
|
||||||
and generated text).
|
and generated text).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
|
def __init__(
|
||||||
max_seq_len: int) -> None:
|
self,
|
||||||
self.max_num_batched_tokens = max_num_batched_tokens
|
max_num_batched_tokens: Optional[int],
|
||||||
|
max_num_seqs: int,
|
||||||
|
max_model_len: int,
|
||||||
|
) -> None:
|
||||||
|
if max_num_batched_tokens is not None:
|
||||||
|
self.max_num_batched_tokens = max_num_batched_tokens
|
||||||
|
else:
|
||||||
|
# If max_model_len is too short, use 2048 as the default value for
|
||||||
|
# higher throughput.
|
||||||
|
self.max_num_batched_tokens = max(max_model_len, 2048)
|
||||||
self.max_num_seqs = max_num_seqs
|
self.max_num_seqs = max_num_seqs
|
||||||
self.max_seq_len = max_seq_len
|
self.max_model_len = max_model_len
|
||||||
|
self._verify_args()
|
||||||
|
|
||||||
|
def _verify_args(self) -> None:
|
||||||
|
if self.max_num_batched_tokens < self.max_model_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
|
||||||
|
f"smaller than max_model_len ({self.max_model_len}). "
|
||||||
|
"This effectively limits the maximum sequence length to "
|
||||||
|
"max_num_batched_tokens and makes vLLM reject longer "
|
||||||
|
"sequences. Please increase max_num_batched_tokens or "
|
||||||
|
"decrease max_model_len.")
|
||||||
|
if self.max_num_batched_tokens < self.max_num_seqs:
|
||||||
|
raise ValueError(
|
||||||
|
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
|
||||||
|
"be greater than or equal to max_num_seqs "
|
||||||
|
f"({self.max_num_seqs}).")
|
||||||
|
|
||||||
|
|
||||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||||
@@ -252,3 +350,56 @@ def _get_and_verify_dtype(
|
|||||||
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
|
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
|
||||||
f"{compute_capability[0]}.{compute_capability[1]}.")
|
f"{compute_capability[0]}.{compute_capability[1]}.")
|
||||||
return torch_dtype
|
return torch_dtype
|
||||||
|
|
||||||
|
|
||||||
|
def _get_and_verify_max_len(
|
||||||
|
hf_config: PretrainedConfig,
|
||||||
|
max_model_len: Optional[int],
|
||||||
|
) -> int:
|
||||||
|
"""Get and verify the model's maximum length."""
|
||||||
|
derived_max_model_len = float("inf")
|
||||||
|
possible_keys = [
|
||||||
|
# OPT
|
||||||
|
"max_position_embeddings",
|
||||||
|
# GPT-2
|
||||||
|
"n_positions",
|
||||||
|
# MPT
|
||||||
|
"max_seq_len",
|
||||||
|
# Others
|
||||||
|
"max_sequence_length",
|
||||||
|
"max_seq_length",
|
||||||
|
"seq_len",
|
||||||
|
]
|
||||||
|
for key in possible_keys:
|
||||||
|
max_len_key = getattr(hf_config, key, None)
|
||||||
|
if max_len_key is not None:
|
||||||
|
derived_max_model_len = min(derived_max_model_len, max_len_key)
|
||||||
|
if derived_max_model_len == float("inf"):
|
||||||
|
if max_model_len is not None:
|
||||||
|
# If max_model_len is specified, we use it.
|
||||||
|
return max_model_len
|
||||||
|
|
||||||
|
default_max_len = 2048
|
||||||
|
logger.warning(
|
||||||
|
"The model's config.json does not contain any of the following "
|
||||||
|
"keys to determine the original maximum length of the model: "
|
||||||
|
f"{possible_keys}. Assuming the model's maximum length is "
|
||||||
|
f"{default_max_len}.")
|
||||||
|
derived_max_model_len = default_max_len
|
||||||
|
|
||||||
|
rope_scaling = getattr(hf_config, "rope_scaling", None)
|
||||||
|
if rope_scaling is not None:
|
||||||
|
assert "factor" in rope_scaling
|
||||||
|
scaling_factor = rope_scaling["factor"]
|
||||||
|
derived_max_model_len *= scaling_factor
|
||||||
|
|
||||||
|
if max_model_len is None:
|
||||||
|
max_model_len = derived_max_model_len
|
||||||
|
elif max_model_len > derived_max_model_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"User-specified max_model_len ({max_model_len}) is greater than "
|
||||||
|
f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
|
||||||
|
" in model's config.json). This may lead to incorrect model "
|
||||||
|
"outputs or CUDA errors. Make sure the value is correct and "
|
||||||
|
"within the model context size.")
|
||||||
|
return int(max_model_len)
|
||||||
|
|||||||
@@ -63,10 +63,18 @@ class BlockSpaceManager:
|
|||||||
num_gpu_blocks: int,
|
num_gpu_blocks: int,
|
||||||
num_cpu_blocks: int,
|
num_cpu_blocks: int,
|
||||||
watermark: float = 0.01,
|
watermark: float = 0.01,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.block_size = block_size
|
self.block_size = block_size
|
||||||
self.num_total_gpu_blocks = num_gpu_blocks
|
self.num_total_gpu_blocks = num_gpu_blocks
|
||||||
self.num_total_cpu_blocks = num_cpu_blocks
|
self.num_total_cpu_blocks = num_cpu_blocks
|
||||||
|
|
||||||
|
self.block_sliding_window = None
|
||||||
|
if sliding_window is not None:
|
||||||
|
assert sliding_window % block_size == 0, (sliding_window,
|
||||||
|
block_size)
|
||||||
|
self.block_sliding_window = sliding_window // block_size
|
||||||
|
|
||||||
self.watermark = watermark
|
self.watermark = watermark
|
||||||
assert watermark >= 0.0
|
assert watermark >= 0.0
|
||||||
|
|
||||||
@@ -83,6 +91,9 @@ class BlockSpaceManager:
|
|||||||
# the same prompt. This may not be true for preempted sequences.
|
# the same prompt. This may not be true for preempted sequences.
|
||||||
seq = seq_group.get_seqs()[0]
|
seq = seq_group.get_seqs()[0]
|
||||||
num_required_blocks = len(seq.logical_token_blocks)
|
num_required_blocks = len(seq.logical_token_blocks)
|
||||||
|
if self.block_sliding_window is not None:
|
||||||
|
num_required_blocks = min(num_required_blocks,
|
||||||
|
self.block_sliding_window)
|
||||||
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
|
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
|
||||||
# Use watermark to avoid frequent cache eviction.
|
# Use watermark to avoid frequent cache eviction.
|
||||||
return (num_free_gpu_blocks - num_required_blocks >=
|
return (num_free_gpu_blocks - num_required_blocks >=
|
||||||
@@ -95,8 +106,12 @@ class BlockSpaceManager:
|
|||||||
|
|
||||||
# Allocate new physical token blocks that will store the prompt tokens.
|
# Allocate new physical token blocks that will store the prompt tokens.
|
||||||
block_table: BlockTable = []
|
block_table: BlockTable = []
|
||||||
for _ in range(len(seq.logical_token_blocks)):
|
for logical_idx in range(len(seq.logical_token_blocks)):
|
||||||
block = self.gpu_allocator.allocate()
|
if (self.block_sliding_window is not None
|
||||||
|
and logical_idx >= self.block_sliding_window):
|
||||||
|
block = block_table[logical_idx % self.block_sliding_window]
|
||||||
|
else:
|
||||||
|
block = self.gpu_allocator.allocate()
|
||||||
# Set the reference counts of the token blocks.
|
# Set the reference counts of the token blocks.
|
||||||
block.ref_count = seq_group.num_seqs()
|
block.ref_count = seq_group.num_seqs()
|
||||||
block_table.append(block)
|
block_table.append(block)
|
||||||
@@ -118,11 +133,17 @@ class BlockSpaceManager:
|
|||||||
block_table = self.block_tables[seq.seq_id]
|
block_table = self.block_tables[seq.seq_id]
|
||||||
|
|
||||||
if len(block_table) < len(logical_blocks):
|
if len(block_table) < len(logical_blocks):
|
||||||
# The sequence has a new logical block.
|
if (self.block_sliding_window
|
||||||
# Allocate a new physical block.
|
and len(block_table) >= self.block_sliding_window):
|
||||||
block = self.gpu_allocator.allocate()
|
# re-use a block
|
||||||
block_table.append(block)
|
block_table.append(block_table[len(block_table) %
|
||||||
return None
|
self.block_sliding_window])
|
||||||
|
else:
|
||||||
|
# The sequence has a new logical block.
|
||||||
|
# Allocate a new physical block.
|
||||||
|
block = self.gpu_allocator.allocate()
|
||||||
|
block_table.append(block)
|
||||||
|
return None
|
||||||
|
|
||||||
# We want to append the token to the last physical block.
|
# We want to append the token to the last physical block.
|
||||||
last_block = block_table[-1]
|
last_block = block_table[-1]
|
||||||
@@ -154,9 +175,7 @@ class BlockSpaceManager:
|
|||||||
for seq in seq_group.get_seqs():
|
for seq in seq_group.get_seqs():
|
||||||
if seq.is_finished():
|
if seq.is_finished():
|
||||||
continue
|
continue
|
||||||
block_table = self.block_tables[seq.seq_id]
|
blocks.update(self.block_tables[seq.seq_id])
|
||||||
for block in block_table:
|
|
||||||
blocks.add(block)
|
|
||||||
return list(blocks)
|
return list(blocks)
|
||||||
|
|
||||||
def can_swap_in(self, seq_group: SequenceGroup) -> bool:
|
def can_swap_in(self, seq_group: SequenceGroup) -> bool:
|
||||||
@@ -172,9 +191,7 @@ class BlockSpaceManager:
|
|||||||
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
|
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
|
||||||
# CPU block -> GPU block.
|
# CPU block -> GPU block.
|
||||||
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
|
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
|
||||||
for seq in seq_group.get_seqs():
|
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
|
||||||
if seq.is_finished():
|
|
||||||
continue
|
|
||||||
new_block_table: BlockTable = []
|
new_block_table: BlockTable = []
|
||||||
block_table = self.block_tables[seq.seq_id]
|
block_table = self.block_tables[seq.seq_id]
|
||||||
|
|
||||||
@@ -203,9 +220,7 @@ class BlockSpaceManager:
|
|||||||
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
|
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
|
||||||
# GPU block -> CPU block.
|
# GPU block -> CPU block.
|
||||||
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
|
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
|
||||||
for seq in seq_group.get_seqs():
|
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||||
if seq.is_finished():
|
|
||||||
continue
|
|
||||||
new_block_table: BlockTable = []
|
new_block_table: BlockTable = []
|
||||||
block_table = self.block_tables[seq.seq_id]
|
block_table = self.block_tables[seq.seq_id]
|
||||||
|
|
||||||
@@ -228,7 +243,7 @@ class BlockSpaceManager:
|
|||||||
return block_number_mapping
|
return block_number_mapping
|
||||||
|
|
||||||
def _free_block_table(self, block_table: BlockTable) -> None:
|
def _free_block_table(self, block_table: BlockTable) -> None:
|
||||||
for block in block_table:
|
for block in set(block_table):
|
||||||
if block.device == Device.GPU:
|
if block.device == Device.GPU:
|
||||||
self.gpu_allocator.free(block)
|
self.gpu_allocator.free(block)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,19 +1,16 @@
|
|||||||
import enum
|
import enum
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from vllm.config import CacheConfig, SchedulerConfig
|
from vllm.config import CacheConfig, SchedulerConfig
|
||||||
from vllm.core.block_manager import BlockSpaceManager
|
from vllm.core.block_manager import BlockSpaceManager
|
||||||
from vllm.core.policy import PolicyFactory
|
from vllm.core.policy import PolicyFactory
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
|
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
|
||||||
SequenceGroupMetadata, SequenceOutputs,
|
SequenceGroupMetadata, SequenceStatus)
|
||||||
SequenceStatus)
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
_LOGGING_INTERVAL_SEC = 5
|
|
||||||
|
|
||||||
|
|
||||||
class PreemptionMode(enum.Enum):
|
class PreemptionMode(enum.Enum):
|
||||||
"""Preemption modes.
|
"""Preemption modes.
|
||||||
@@ -32,19 +29,28 @@ class SchedulerOutputs:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
scheduled_seq_groups: List[SequenceGroup],
|
||||||
|
prompt_run: bool,
|
||||||
|
num_batched_tokens: int,
|
||||||
blocks_to_swap_in: Dict[int, int],
|
blocks_to_swap_in: Dict[int, int],
|
||||||
blocks_to_swap_out: Dict[int, int],
|
blocks_to_swap_out: Dict[int, int],
|
||||||
blocks_to_copy: Dict[int, List[int]],
|
blocks_to_copy: Dict[int, List[int]],
|
||||||
|
ignored_seq_groups: List[SequenceGroup],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
self.scheduled_seq_groups = scheduled_seq_groups
|
||||||
|
self.prompt_run = prompt_run
|
||||||
|
self.num_batched_tokens = num_batched_tokens
|
||||||
self.blocks_to_swap_in = blocks_to_swap_in
|
self.blocks_to_swap_in = blocks_to_swap_in
|
||||||
self.blocks_to_swap_out = blocks_to_swap_out
|
self.blocks_to_swap_out = blocks_to_swap_out
|
||||||
self.blocks_to_copy = blocks_to_copy
|
self.blocks_to_copy = blocks_to_copy
|
||||||
# Swap in and swap out should never happen at the same time.
|
# Swap in and swap out should never happen at the same time.
|
||||||
assert not (blocks_to_swap_in and blocks_to_swap_out)
|
assert not (blocks_to_swap_in and blocks_to_swap_out)
|
||||||
|
self.ignored_seq_groups = ignored_seq_groups
|
||||||
|
|
||||||
def is_empty(self) -> bool:
|
def is_empty(self) -> bool:
|
||||||
return (not self.blocks_to_swap_in and not self.blocks_to_swap_out
|
# NOTE: We do not consider the ignored sequence groups.
|
||||||
and not self.blocks_to_copy)
|
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
|
||||||
|
and not self.blocks_to_swap_out and not self.blocks_to_copy)
|
||||||
|
|
||||||
|
|
||||||
class Scheduler:
|
class Scheduler:
|
||||||
@@ -53,11 +59,12 @@ class Scheduler:
|
|||||||
self,
|
self,
|
||||||
scheduler_config: SchedulerConfig,
|
scheduler_config: SchedulerConfig,
|
||||||
cache_config: CacheConfig,
|
cache_config: CacheConfig,
|
||||||
log_stats: bool,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.scheduler_config = scheduler_config
|
self.scheduler_config = scheduler_config
|
||||||
self.cache_config = cache_config
|
self.cache_config = cache_config
|
||||||
self.log_stats = log_stats
|
|
||||||
|
self.prompt_limit = min(self.scheduler_config.max_model_len,
|
||||||
|
self.scheduler_config.max_num_batched_tokens)
|
||||||
|
|
||||||
# Instantiate the scheduling policy.
|
# Instantiate the scheduling policy.
|
||||||
self.policy = PolicyFactory.get_policy(policy_name="fcfs")
|
self.policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||||
@@ -66,8 +73,9 @@ class Scheduler:
|
|||||||
block_size=self.cache_config.block_size,
|
block_size=self.cache_config.block_size,
|
||||||
num_gpu_blocks=self.cache_config.num_gpu_blocks,
|
num_gpu_blocks=self.cache_config.num_gpu_blocks,
|
||||||
num_cpu_blocks=self.cache_config.num_cpu_blocks,
|
num_cpu_blocks=self.cache_config.num_cpu_blocks,
|
||||||
)
|
sliding_window=self.cache_config.sliding_window)
|
||||||
|
|
||||||
|
# TODO(zhuohan): Use deque instead of list for better performance.
|
||||||
# Sequence groups in the WAITING state.
|
# Sequence groups in the WAITING state.
|
||||||
self.waiting: List[SequenceGroup] = []
|
self.waiting: List[SequenceGroup] = []
|
||||||
# Sequence groups in the RUNNING state.
|
# Sequence groups in the RUNNING state.
|
||||||
@@ -75,25 +83,30 @@ class Scheduler:
|
|||||||
# Sequence groups in the SWAPPED state.
|
# Sequence groups in the SWAPPED state.
|
||||||
self.swapped: List[SequenceGroup] = []
|
self.swapped: List[SequenceGroup] = []
|
||||||
|
|
||||||
self.last_logging_time: float = 0.0
|
|
||||||
# List[timestamp, num_tokens]
|
|
||||||
self.num_input_tokens: List[Tuple[float, int]] = []
|
|
||||||
|
|
||||||
def add_seq_group(self, seq_group: SequenceGroup) -> None:
|
def add_seq_group(self, seq_group: SequenceGroup) -> None:
|
||||||
# Add sequence groups to the waiting queue.
|
# Add sequence groups to the waiting queue.
|
||||||
self.waiting.append(seq_group)
|
self.waiting.append(seq_group)
|
||||||
|
|
||||||
def abort_seq_group(self, request_id: str) -> None:
|
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||||
|
if isinstance(request_id, str):
|
||||||
|
request_id = (request_id, )
|
||||||
|
request_ids = set(request_id)
|
||||||
for state_queue in [self.waiting, self.running, self.swapped]:
|
for state_queue in [self.waiting, self.running, self.swapped]:
|
||||||
for seq_group in state_queue:
|
# We need to reverse the list as we are removing elements
|
||||||
if seq_group.request_id == request_id:
|
# from it as we iterate over it. If we don't do it,
|
||||||
|
# indices will get messed up and we will skip over elements.
|
||||||
|
for seq_group in reversed(state_queue):
|
||||||
|
if seq_group.request_id in request_ids:
|
||||||
# Remove the sequence group from the state queue.
|
# Remove the sequence group from the state queue.
|
||||||
state_queue.remove(seq_group)
|
state_queue.remove(seq_group)
|
||||||
for seq in seq_group.seqs:
|
for seq in seq_group.get_seqs():
|
||||||
if seq.is_finished():
|
if seq.is_finished():
|
||||||
continue
|
continue
|
||||||
self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
|
seq.status = SequenceStatus.FINISHED_ABORTED
|
||||||
return
|
self.free_seq(seq)
|
||||||
|
request_ids.remove(seq_group.request_id)
|
||||||
|
if not request_ids:
|
||||||
|
return
|
||||||
|
|
||||||
def has_unfinished_seqs(self) -> bool:
|
def has_unfinished_seqs(self) -> bool:
|
||||||
return self.waiting or self.running or self.swapped
|
return self.waiting or self.running or self.swapped
|
||||||
@@ -101,21 +114,81 @@ class Scheduler:
|
|||||||
def get_num_unfinished_seq_groups(self) -> int:
|
def get_num_unfinished_seq_groups(self) -> int:
|
||||||
return len(self.waiting) + len(self.running) + len(self.swapped)
|
return len(self.waiting) + len(self.running) + len(self.swapped)
|
||||||
|
|
||||||
def _schedule(
|
def _schedule(self) -> SchedulerOutputs:
|
||||||
self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
|
|
||||||
# Blocks that need to be swaped or copied before model execution.
|
# Blocks that need to be swaped or copied before model execution.
|
||||||
blocks_to_swap_in: Dict[int, int] = {}
|
blocks_to_swap_in: Dict[int, int] = {}
|
||||||
blocks_to_swap_out: Dict[int, int] = {}
|
blocks_to_swap_out: Dict[int, int] = {}
|
||||||
blocks_to_copy: Dict[int, List[int]] = {}
|
blocks_to_copy: Dict[int, List[int]] = {}
|
||||||
ignored_seq_groups: List[SequenceGroup] = []
|
|
||||||
|
|
||||||
# Fix the current time.
|
# Fix the current time.
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
|
||||||
# NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
|
# Join waiting sequences if possible.
|
||||||
# in order to minimize the preemption overheads.
|
if not self.swapped:
|
||||||
# Preemption happens only when there is no available slot to keep all
|
ignored_seq_groups: List[SequenceGroup] = []
|
||||||
# the sequence groups in the RUNNING state.
|
scheduled: List[SequenceGroup] = []
|
||||||
|
# The total number of sequences on the fly, including the
|
||||||
|
# requests in the generation phase.
|
||||||
|
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
|
||||||
|
for seq_group in self.running)
|
||||||
|
num_batched_tokens = 0
|
||||||
|
# Optimization: We do not sort the waiting queue since the preempted
|
||||||
|
# sequence groups are added to the front and the new sequence groups
|
||||||
|
# are added to the back.
|
||||||
|
while self.waiting:
|
||||||
|
seq_group = self.waiting[0]
|
||||||
|
|
||||||
|
assert seq_group.num_seqs() == 1, (
|
||||||
|
"Waiting sequence group should have only one prompt "
|
||||||
|
"sequence.")
|
||||||
|
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
|
||||||
|
if num_prompt_tokens > self.prompt_limit:
|
||||||
|
logger.warning(
|
||||||
|
f"Input prompt ({num_prompt_tokens} tokens) is too long"
|
||||||
|
f" and exceeds limit of {self.prompt_limit}")
|
||||||
|
for seq in seq_group.get_seqs():
|
||||||
|
seq.status = SequenceStatus.FINISHED_IGNORED
|
||||||
|
ignored_seq_groups.append(seq_group)
|
||||||
|
self.waiting.pop(0)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If the sequence group cannot be allocated, stop.
|
||||||
|
if not self.block_manager.can_allocate(seq_group):
|
||||||
|
break
|
||||||
|
|
||||||
|
# If the number of batched tokens exceeds the limit, stop.
|
||||||
|
if (num_batched_tokens + num_prompt_tokens >
|
||||||
|
self.scheduler_config.max_num_batched_tokens):
|
||||||
|
break
|
||||||
|
|
||||||
|
# The total number of sequences in the RUNNING state should not
|
||||||
|
# exceed the maximum number of sequences.
|
||||||
|
num_new_seqs = seq_group.get_max_num_running_seqs()
|
||||||
|
if (num_curr_seqs + num_new_seqs >
|
||||||
|
self.scheduler_config.max_num_seqs):
|
||||||
|
break
|
||||||
|
|
||||||
|
seq_group = self.waiting.pop(0)
|
||||||
|
self._allocate(seq_group)
|
||||||
|
self.running.append(seq_group)
|
||||||
|
num_batched_tokens += num_prompt_tokens
|
||||||
|
num_curr_seqs += num_new_seqs
|
||||||
|
scheduled.append(seq_group)
|
||||||
|
|
||||||
|
if scheduled or ignored_seq_groups:
|
||||||
|
scheduler_outputs = SchedulerOutputs(
|
||||||
|
scheduled_seq_groups=scheduled,
|
||||||
|
prompt_run=True,
|
||||||
|
num_batched_tokens=num_batched_tokens,
|
||||||
|
blocks_to_swap_in=blocks_to_swap_in,
|
||||||
|
blocks_to_swap_out=blocks_to_swap_out,
|
||||||
|
blocks_to_copy=blocks_to_copy,
|
||||||
|
ignored_seq_groups=ignored_seq_groups,
|
||||||
|
)
|
||||||
|
return scheduler_outputs
|
||||||
|
|
||||||
|
# NOTE(woosuk): Preemption happens only when there is no available slot
|
||||||
|
# to keep all the sequence groups in the RUNNING state.
|
||||||
# In this case, the policy is responsible for deciding which sequence
|
# In this case, the policy is responsible for deciding which sequence
|
||||||
# groups to preempt.
|
# groups to preempt.
|
||||||
self.running = self.policy.sort_by_priority(now, self.running)
|
self.running = self.policy.sort_by_priority(now, self.running)
|
||||||
@@ -145,150 +218,56 @@ class Scheduler:
|
|||||||
|
|
||||||
# Swap in the sequence groups in the SWAPPED state if possible.
|
# Swap in the sequence groups in the SWAPPED state if possible.
|
||||||
self.swapped = self.policy.sort_by_priority(now, self.swapped)
|
self.swapped = self.policy.sort_by_priority(now, self.swapped)
|
||||||
while self.swapped and not blocks_to_swap_out:
|
if not preempted:
|
||||||
seq_group = self.swapped[0]
|
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
|
||||||
# If the sequence group has been preempted in this step, stop.
|
for seq_group in self.running)
|
||||||
if seq_group in preempted:
|
|
||||||
break
|
|
||||||
# If the sequence group cannot be swapped in, stop.
|
|
||||||
if not self.block_manager.can_swap_in(seq_group):
|
|
||||||
break
|
|
||||||
|
|
||||||
# The total number of sequences in the RUNNING state should not
|
while self.swapped:
|
||||||
# exceed the maximum number of sequences.
|
seq_group = self.swapped[0]
|
||||||
num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
|
# If the sequence group cannot be swapped in, stop.
|
||||||
num_curr_seqs = sum(
|
if not self.block_manager.can_swap_in(seq_group):
|
||||||
seq_group.num_seqs(status=SequenceStatus.RUNNING)
|
|
||||||
for seq_group in self.running)
|
|
||||||
if (num_curr_seqs + num_new_seqs >
|
|
||||||
self.scheduler_config.max_num_seqs):
|
|
||||||
break
|
|
||||||
|
|
||||||
seq_group = self.swapped.pop(0)
|
|
||||||
self._swap_in(seq_group, blocks_to_swap_in)
|
|
||||||
self._append_slot(seq_group, blocks_to_copy)
|
|
||||||
self.running.append(seq_group)
|
|
||||||
|
|
||||||
num_batched_tokens = sum(
|
|
||||||
seq_group.num_seqs(status=SequenceStatus.RUNNING)
|
|
||||||
for seq_group in self.running)
|
|
||||||
|
|
||||||
# Join waiting sequences if possible.
|
|
||||||
prompt_group_ids: List[str] = []
|
|
||||||
# NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
|
|
||||||
# prioritized over the sequence groups in the WAITING state.
|
|
||||||
# This is because we want to bound the amount of CPU memory taken by
|
|
||||||
# the swapped sequence groups.
|
|
||||||
if not self.swapped:
|
|
||||||
# Optimization: We do not sort the waiting queue since the preempted
|
|
||||||
# sequence groups are added to the front and the new sequence groups
|
|
||||||
# are added to the back.
|
|
||||||
while self.waiting:
|
|
||||||
seq_group = self.waiting[0]
|
|
||||||
# If the sequence group has been preempted in this step, stop.
|
|
||||||
if seq_group in preempted:
|
|
||||||
break
|
|
||||||
|
|
||||||
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
|
|
||||||
if num_prompt_tokens >= self.scheduler_config.max_seq_len:
|
|
||||||
logger.warning(
|
|
||||||
f"Input prompt ({num_prompt_tokens} tokens) is too long"
|
|
||||||
" and exceeds limit of "
|
|
||||||
f"{self.scheduler_config.max_seq_len}")
|
|
||||||
for seq in seq_group.get_seqs():
|
|
||||||
seq.status = SequenceStatus.FINISHED_IGNORED
|
|
||||||
ignored_seq_groups.append(seq_group)
|
|
||||||
self.waiting.pop(0)
|
|
||||||
break
|
|
||||||
|
|
||||||
# If the sequence group cannot be allocated, stop.
|
|
||||||
if not self.block_manager.can_allocate(seq_group):
|
|
||||||
break
|
|
||||||
|
|
||||||
# If the number of batched tokens exceeds the limit, stop.
|
|
||||||
if (num_batched_tokens + num_prompt_tokens >
|
|
||||||
self.scheduler_config.max_num_batched_tokens):
|
|
||||||
break
|
break
|
||||||
|
|
||||||
# The total number of sequences in the RUNNING state should not
|
# The total number of sequences in the RUNNING state should not
|
||||||
# exceed the maximum number of sequences.
|
# exceed the maximum number of sequences.
|
||||||
num_new_seqs = seq_group.num_seqs(
|
num_new_seqs = seq_group.get_max_num_running_seqs()
|
||||||
status=SequenceStatus.WAITING)
|
|
||||||
num_curr_seqs = sum(
|
|
||||||
seq_group.num_seqs(status=SequenceStatus.RUNNING)
|
|
||||||
for seq_group in self.running)
|
|
||||||
if (num_curr_seqs + num_new_seqs >
|
if (num_curr_seqs + num_new_seqs >
|
||||||
self.scheduler_config.max_num_seqs):
|
self.scheduler_config.max_num_seqs):
|
||||||
break
|
break
|
||||||
|
|
||||||
seq_group = self.waiting.pop(0)
|
seq_group = self.swapped.pop(0)
|
||||||
self._allocate(seq_group)
|
self._swap_in(seq_group, blocks_to_swap_in)
|
||||||
|
self._append_slot(seq_group, blocks_to_copy)
|
||||||
|
num_curr_seqs += num_new_seqs
|
||||||
self.running.append(seq_group)
|
self.running.append(seq_group)
|
||||||
num_batched_tokens += num_prompt_tokens
|
|
||||||
prompt_group_ids.append(seq_group.request_id)
|
# Each sequence in the generation phase only takes one token slot.
|
||||||
|
# Therefore, the number of batched tokens is equal to the number of
|
||||||
|
# sequences in the RUNNING state.
|
||||||
|
num_batched_tokens = sum(
|
||||||
|
seq_group.num_seqs(status=SequenceStatus.RUNNING)
|
||||||
|
for seq_group in self.running)
|
||||||
|
|
||||||
scheduler_outputs = SchedulerOutputs(
|
scheduler_outputs = SchedulerOutputs(
|
||||||
|
scheduled_seq_groups=self.running,
|
||||||
|
prompt_run=False,
|
||||||
|
num_batched_tokens=num_batched_tokens,
|
||||||
blocks_to_swap_in=blocks_to_swap_in,
|
blocks_to_swap_in=blocks_to_swap_in,
|
||||||
blocks_to_swap_out=blocks_to_swap_out,
|
blocks_to_swap_out=blocks_to_swap_out,
|
||||||
blocks_to_copy=blocks_to_copy,
|
blocks_to_copy=blocks_to_copy,
|
||||||
|
ignored_seq_groups=[],
|
||||||
)
|
)
|
||||||
if not self.log_stats:
|
return scheduler_outputs
|
||||||
return scheduler_outputs, prompt_group_ids, ignored_seq_groups
|
|
||||||
|
|
||||||
# TODO(woosuk): Move the below code to the engine.
|
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
|
||||||
now = time.time()
|
|
||||||
if num_batched_tokens > 0:
|
|
||||||
self.num_input_tokens.append((now, num_batched_tokens))
|
|
||||||
elapsed_time = now - self.last_logging_time
|
|
||||||
if elapsed_time > _LOGGING_INTERVAL_SEC:
|
|
||||||
self.last_logging_time = now
|
|
||||||
self.num_input_tokens = [(t, n) for t, n in self.num_input_tokens
|
|
||||||
if now - t < _LOGGING_INTERVAL_SEC]
|
|
||||||
if len(self.num_input_tokens) > 1:
|
|
||||||
total_num_tokens = sum(n
|
|
||||||
for _, n in self.num_input_tokens[:-1])
|
|
||||||
window = now - self.num_input_tokens[0][0]
|
|
||||||
avg_throughput = total_num_tokens / window
|
|
||||||
else:
|
|
||||||
avg_throughput = 0.0
|
|
||||||
|
|
||||||
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
|
|
||||||
num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
|
|
||||||
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
|
|
||||||
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
|
|
||||||
|
|
||||||
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
|
|
||||||
if total_num_cpu_blocks > 0:
|
|
||||||
num_free_cpu_blocks = (
|
|
||||||
self.block_manager.get_num_free_cpu_blocks())
|
|
||||||
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
|
|
||||||
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
|
|
||||||
else:
|
|
||||||
cpu_cache_usage = 0.0
|
|
||||||
|
|
||||||
logger.info(f"Throughput: {avg_throughput:.1f} tokens/s, "
|
|
||||||
f"Running: {len(self.running)} reqs, "
|
|
||||||
f"Swapped: {len(self.swapped)} reqs, "
|
|
||||||
f"Pending: {len(self.waiting)} reqs, "
|
|
||||||
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
|
|
||||||
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
|
|
||||||
return scheduler_outputs, prompt_group_ids, ignored_seq_groups
|
|
||||||
|
|
||||||
def schedule(
|
|
||||||
self
|
|
||||||
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
|
|
||||||
List[SequenceGroup]]:
|
|
||||||
# Schedule sequence groups.
|
# Schedule sequence groups.
|
||||||
# This function call changes the internal states of the scheduler
|
# This function call changes the internal states of the scheduler
|
||||||
# such as self.running, self.swapped, and self.waiting.
|
# such as self.running, self.swapped, and self.waiting.
|
||||||
(scheduler_outputs, prompt_group_ids,
|
scheduler_outputs = self._schedule()
|
||||||
ignored_seq_groups) = self._schedule()
|
|
||||||
|
|
||||||
# Create input data structures.
|
# Create input data structures.
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||||
for seq_group in self.running:
|
for seq_group in scheduler_outputs.scheduled_seq_groups:
|
||||||
is_prompt = seq_group.request_id in prompt_group_ids
|
|
||||||
|
|
||||||
seq_data: Dict[int, List[SequenceData]] = {}
|
seq_data: Dict[int, List[SequenceData]] = {}
|
||||||
block_tables: Dict[int, List[int]] = {}
|
block_tables: Dict[int, List[int]] = {}
|
||||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||||
@@ -298,43 +277,18 @@ class Scheduler:
|
|||||||
|
|
||||||
seq_group_metadata = SequenceGroupMetadata(
|
seq_group_metadata = SequenceGroupMetadata(
|
||||||
request_id=seq_group.request_id,
|
request_id=seq_group.request_id,
|
||||||
is_prompt=is_prompt,
|
is_prompt=scheduler_outputs.prompt_run,
|
||||||
seq_data=seq_data,
|
seq_data=seq_data,
|
||||||
sampling_params=seq_group.sampling_params,
|
sampling_params=seq_group.sampling_params,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
)
|
)
|
||||||
seq_group_metadata_list.append(seq_group_metadata)
|
seq_group_metadata_list.append(seq_group_metadata)
|
||||||
return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups
|
return seq_group_metadata_list, scheduler_outputs
|
||||||
|
|
||||||
def update(
|
def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
||||||
self,
|
self.block_manager.fork(parent_seq, child_seq)
|
||||||
seq_outputs: Dict[int, SequenceOutputs],
|
|
||||||
) -> List[SequenceGroup]:
|
|
||||||
# Update the running sequences and free blocks.
|
|
||||||
for seq_group in self.running:
|
|
||||||
# Process beam search results before processing the new tokens.
|
|
||||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
|
||||||
output = seq_outputs[seq.seq_id]
|
|
||||||
if seq.seq_id != output.parent_seq_id:
|
|
||||||
# The sequence is a fork of the parent sequence (beam
|
|
||||||
# search). Free the current sequence.
|
|
||||||
self.block_manager.free(seq)
|
|
||||||
# Fork the parent sequence.
|
|
||||||
parent_seq = seq_group.find(output.parent_seq_id)
|
|
||||||
parent_seq.fork(seq)
|
|
||||||
self.block_manager.fork(parent_seq, seq)
|
|
||||||
|
|
||||||
# Process the new tokens.
|
def free_seq(self, seq: Sequence) -> None:
|
||||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
|
||||||
# Append a new token to the sequence.
|
|
||||||
output = seq_outputs[seq.seq_id]
|
|
||||||
seq.append_token_id(output.output_token, output.logprobs)
|
|
||||||
# Return a shallow copy of the running queue to prevent the queue
|
|
||||||
# from being modified by the caller.
|
|
||||||
return self.running.copy()
|
|
||||||
|
|
||||||
def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
|
|
||||||
seq.status = finish_status
|
|
||||||
self.block_manager.free(seq)
|
self.block_manager.free(seq)
|
||||||
|
|
||||||
def free_finished_seq_groups(self) -> None:
|
def free_finished_seq_groups(self) -> None:
|
||||||
@@ -371,8 +325,8 @@ class Scheduler:
|
|||||||
# If preemption mode is not specified, we determine the mode as follows:
|
# If preemption mode is not specified, we determine the mode as follows:
|
||||||
# We use recomputation by default since it incurs lower overhead than
|
# We use recomputation by default since it incurs lower overhead than
|
||||||
# swapping. However, when the sequence group has multiple sequences
|
# swapping. However, when the sequence group has multiple sequences
|
||||||
# (e.g., beam search), recomputation is not supported. In such a case,
|
# (e.g., beam search), recomputation is not currently supported. In
|
||||||
# we use swapping instead.
|
# such a case, we use swapping instead.
|
||||||
# FIXME(woosuk): This makes our scheduling policy a bit bizarre.
|
# FIXME(woosuk): This makes our scheduling policy a bit bizarre.
|
||||||
# As swapped sequences are prioritized over waiting sequences,
|
# As swapped sequences are prioritized over waiting sequences,
|
||||||
# sequence groups with multiple sequences are implicitly prioritized
|
# sequence groups with multiple sequences are implicitly prioritized
|
||||||
@@ -380,8 +334,7 @@ class Scheduler:
|
|||||||
# TODO(woosuk): Support recomputation for sequence groups with multiple
|
# TODO(woosuk): Support recomputation for sequence groups with multiple
|
||||||
# sequences. This may require a more sophisticated CUDA kernel.
|
# sequences. This may require a more sophisticated CUDA kernel.
|
||||||
if preemption_mode is None:
|
if preemption_mode is None:
|
||||||
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
if seq_group.get_max_num_running_seqs() == 1:
|
||||||
if len(seqs) == 1:
|
|
||||||
preemption_mode = PreemptionMode.RECOMPUTE
|
preemption_mode = PreemptionMode.RECOMPUTE
|
||||||
else:
|
else:
|
||||||
preemption_mode = PreemptionMode.SWAP
|
preemption_mode = PreemptionMode.SWAP
|
||||||
@@ -410,9 +363,6 @@ class Scheduler:
|
|||||||
seq_group: SequenceGroup,
|
seq_group: SequenceGroup,
|
||||||
blocks_to_swap_out: Dict[int, int],
|
blocks_to_swap_out: Dict[int, int],
|
||||||
) -> None:
|
) -> None:
|
||||||
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
|
||||||
for seq in seqs:
|
|
||||||
seq.status = SequenceStatus.SWAPPED
|
|
||||||
self._swap_out(seq_group, blocks_to_swap_out)
|
self._swap_out(seq_group, blocks_to_swap_out)
|
||||||
self.swapped.append(seq_group)
|
self.swapped.append(seq_group)
|
||||||
|
|
||||||
|
|||||||
@@ -13,25 +13,27 @@ class EngineArgs:
|
|||||||
model: str
|
model: str
|
||||||
tokenizer: Optional[str] = None
|
tokenizer: Optional[str] = None
|
||||||
tokenizer_mode: str = 'auto'
|
tokenizer_mode: str = 'auto'
|
||||||
|
trust_remote_code: bool = False
|
||||||
download_dir: Optional[str] = None
|
download_dir: Optional[str] = None
|
||||||
use_np_weights: bool = False
|
load_format: str = 'auto'
|
||||||
use_dummy_weights: bool = False
|
|
||||||
dtype: str = 'auto'
|
dtype: str = 'auto'
|
||||||
seed: int = 0
|
seed: int = 0
|
||||||
|
max_model_len: Optional[int] = None
|
||||||
worker_use_ray: bool = False
|
worker_use_ray: bool = False
|
||||||
pipeline_parallel_size: int = 1
|
pipeline_parallel_size: int = 1
|
||||||
tensor_parallel_size: int = 1
|
tensor_parallel_size: int = 1
|
||||||
block_size: int = 16
|
block_size: int = 16
|
||||||
swap_space: int = 4 # GiB
|
swap_space: int = 4 # GiB
|
||||||
gpu_memory_utilization: float = 0.90
|
gpu_memory_utilization: float = 0.90
|
||||||
max_num_batched_tokens: int = 2560
|
max_num_batched_tokens: Optional[int] = None
|
||||||
max_num_seqs: int = 256
|
max_num_seqs: int = 256
|
||||||
disable_log_stats: bool = False
|
disable_log_stats: bool = False
|
||||||
|
revision: Optional[str] = None
|
||||||
|
quantization: Optional[str] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.tokenizer is None:
|
if self.tokenizer is None:
|
||||||
self.tokenizer = self.model
|
self.tokenizer = self.model
|
||||||
self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(
|
def add_cli_args(
|
||||||
@@ -48,6 +50,13 @@ class EngineArgs:
|
|||||||
type=str,
|
type=str,
|
||||||
default=EngineArgs.tokenizer,
|
default=EngineArgs.tokenizer,
|
||||||
help='name or path of the huggingface tokenizer to use')
|
help='name or path of the huggingface tokenizer to use')
|
||||||
|
parser.add_argument(
|
||||||
|
'--revision',
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help='the specific model version to use. It can be a branch '
|
||||||
|
'name, a tag name, or a commit id. If unspecified, will use '
|
||||||
|
'the default version.')
|
||||||
parser.add_argument('--tokenizer-mode',
|
parser.add_argument('--tokenizer-mode',
|
||||||
type=str,
|
type=str,
|
||||||
default=EngineArgs.tokenizer_mode,
|
default=EngineArgs.tokenizer_mode,
|
||||||
@@ -55,30 +64,46 @@ class EngineArgs:
|
|||||||
help='tokenizer mode. "auto" will use the fast '
|
help='tokenizer mode. "auto" will use the fast '
|
||||||
'tokenizer if available, and "slow" will '
|
'tokenizer if available, and "slow" will '
|
||||||
'always use the slow tokenizer.')
|
'always use the slow tokenizer.')
|
||||||
|
parser.add_argument('--trust-remote-code',
|
||||||
|
action='store_true',
|
||||||
|
help='trust remote code from huggingface')
|
||||||
parser.add_argument('--download-dir',
|
parser.add_argument('--download-dir',
|
||||||
type=str,
|
type=str,
|
||||||
default=EngineArgs.download_dir,
|
default=EngineArgs.download_dir,
|
||||||
help='directory to download and load the weights, '
|
help='directory to download and load the weights, '
|
||||||
'default to the default cache dir of '
|
'default to the default cache dir of '
|
||||||
'huggingface')
|
'huggingface')
|
||||||
parser.add_argument('--use-np-weights',
|
parser.add_argument(
|
||||||
action='store_true',
|
'--load-format',
|
||||||
help='save a numpy copy of model weights for '
|
type=str,
|
||||||
'faster loading. This can increase the disk '
|
default=EngineArgs.load_format,
|
||||||
'usage by up to 2x.')
|
choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
|
||||||
parser.add_argument('--use-dummy-weights',
|
help='The format of the model weights to load. '
|
||||||
action='store_true',
|
'"auto" will try to load the weights in the safetensors format '
|
||||||
help='use dummy values for model weights')
|
'and fall back to the pytorch bin format if safetensors format '
|
||||||
# TODO(woosuk): Support FP32.
|
'is not available. '
|
||||||
|
'"pt" will load the weights in the pytorch bin format. '
|
||||||
|
'"safetensors" will load the weights in the safetensors format. '
|
||||||
|
'"npcache" will load the weights in pytorch format and store '
|
||||||
|
'a numpy cache to speed up the loading. '
|
||||||
|
'"dummy" will initialize the weights with random values, '
|
||||||
|
'which is mainly for profiling.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--dtype',
|
'--dtype',
|
||||||
type=str,
|
type=str,
|
||||||
default=EngineArgs.dtype,
|
default=EngineArgs.dtype,
|
||||||
choices=['auto', 'half', 'bfloat16', 'float'],
|
choices=[
|
||||||
|
'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
|
||||||
|
],
|
||||||
help='data type for model weights and activations. '
|
help='data type for model weights and activations. '
|
||||||
'The "auto" option will use FP16 precision '
|
'The "auto" option will use FP16 precision '
|
||||||
'for FP32 and FP16 models, and BF16 precision '
|
'for FP32 and FP16 models, and BF16 precision '
|
||||||
'for BF16 models.')
|
'for BF16 models.')
|
||||||
|
parser.add_argument('--max-model-len',
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help='model context length. If unspecified, '
|
||||||
|
'will be automatically derived from the model.')
|
||||||
# Parallel arguments
|
# Parallel arguments
|
||||||
parser.add_argument('--worker-use-ray',
|
parser.add_argument('--worker-use-ray',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
@@ -126,6 +151,13 @@ class EngineArgs:
|
|||||||
parser.add_argument('--disable-log-stats',
|
parser.add_argument('--disable-log-stats',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='disable logging statistics')
|
help='disable logging statistics')
|
||||||
|
# Quantization settings.
|
||||||
|
parser.add_argument('--quantization',
|
||||||
|
'-q',
|
||||||
|
type=str,
|
||||||
|
choices=['awq', None],
|
||||||
|
default=None,
|
||||||
|
help='Method used to quantize the weights')
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -139,22 +171,20 @@ class EngineArgs:
|
|||||||
def create_engine_configs(
|
def create_engine_configs(
|
||||||
self,
|
self,
|
||||||
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
|
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
|
||||||
# Initialize the configs.
|
|
||||||
model_config = ModelConfig(self.model, self.tokenizer,
|
model_config = ModelConfig(self.model, self.tokenizer,
|
||||||
self.tokenizer_mode, self.download_dir,
|
self.tokenizer_mode, self.trust_remote_code,
|
||||||
self.use_np_weights, self.use_dummy_weights,
|
self.download_dir, self.load_format,
|
||||||
self.dtype, self.seed)
|
self.dtype, self.seed, self.revision,
|
||||||
cache_config = CacheConfig(self.block_size,
|
self.max_model_len, self.quantization)
|
||||||
self.gpu_memory_utilization,
|
cache_config = CacheConfig(
|
||||||
self.swap_space)
|
self.block_size, self.gpu_memory_utilization, self.swap_space,
|
||||||
|
getattr(model_config.hf_config, 'sliding_window', None))
|
||||||
parallel_config = ParallelConfig(self.pipeline_parallel_size,
|
parallel_config = ParallelConfig(self.pipeline_parallel_size,
|
||||||
self.tensor_parallel_size,
|
self.tensor_parallel_size,
|
||||||
self.worker_use_ray)
|
self.worker_use_ray)
|
||||||
model_max_len = getattr(model_config.hf_config,
|
|
||||||
'max_position_embeddings', float('inf'))
|
|
||||||
max_seq_len = min(self.max_num_batched_tokens, model_max_len)
|
|
||||||
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
|
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
|
||||||
self.max_num_seqs, max_seq_len)
|
self.max_num_seqs,
|
||||||
|
model_config.max_model_len)
|
||||||
return model_config, cache_config, parallel_config, scheduler_config
|
return model_config, cache_config, parallel_config, scheduler_config
|
||||||
|
|
||||||
|
|
||||||
@@ -163,6 +193,7 @@ class AsyncEngineArgs(EngineArgs):
|
|||||||
"""Arguments for asynchronous vLLM engine."""
|
"""Arguments for asynchronous vLLM engine."""
|
||||||
engine_use_ray: bool = False
|
engine_use_ray: bool = False
|
||||||
disable_log_requests: bool = False
|
disable_log_requests: bool = False
|
||||||
|
max_log_len: Optional[int] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(
|
def add_cli_args(
|
||||||
@@ -175,4 +206,10 @@ class AsyncEngineArgs(EngineArgs):
|
|||||||
parser.add_argument('--disable-log-requests',
|
parser.add_argument('--disable-log-requests',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='disable logging requests')
|
help='disable logging requests')
|
||||||
|
parser.add_argument('--max-log-len',
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help='max number of prompt characters or prompt '
|
||||||
|
'ID numbers being printed in log. '
|
||||||
|
'Default: unlimited.')
|
||||||
return parser
|
return parser
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional
|
from functools import partial
|
||||||
|
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
|
||||||
|
Union)
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
@@ -12,7 +14,219 @@ from vllm.sampling_params import SamplingParams
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
|
|
||||||
|
class AsyncEngineDeadError(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_exception_on_finish(task: asyncio.Task,
|
||||||
|
request_tracker: "RequestTracker") -> None:
|
||||||
|
msg = ("Task finished unexpectedly. This should never happen! "
|
||||||
|
"Please open an issue on Github.")
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
task.result()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
return
|
||||||
|
except Exception as exc:
|
||||||
|
raise AsyncEngineDeadError(
|
||||||
|
msg + " See stack trace above for the actual cause.") from exc
|
||||||
|
raise AsyncEngineDeadError(msg)
|
||||||
|
except Exception as exc:
|
||||||
|
request_tracker.propagate_exception(exc)
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncStream:
|
||||||
|
"""A stream of RequestOutputs for a request that can be
|
||||||
|
iterated over asynchronously."""
|
||||||
|
|
||||||
|
def __init__(self, request_id: str) -> None:
|
||||||
|
self.request_id = request_id
|
||||||
|
self._queue = asyncio.Queue()
|
||||||
|
self._finished = False
|
||||||
|
|
||||||
|
def put(self, item: RequestOutput) -> None:
|
||||||
|
if self._finished:
|
||||||
|
return
|
||||||
|
self._queue.put_nowait(item)
|
||||||
|
|
||||||
|
def finish(self) -> None:
|
||||||
|
self._queue.put_nowait(StopIteration)
|
||||||
|
self._finished = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def finished(self) -> bool:
|
||||||
|
return self._finished
|
||||||
|
|
||||||
|
def __aiter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self) -> RequestOutput:
|
||||||
|
result = await self._queue.get()
|
||||||
|
if result is StopIteration:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
elif isinstance(result, Exception):
|
||||||
|
raise result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class RequestTracker:
|
||||||
|
"""Synchronous abstraction for tracking requests."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._request_streams: Dict[str, AsyncStream] = {}
|
||||||
|
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
|
||||||
|
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
|
||||||
|
dict]] = asyncio.Queue()
|
||||||
|
self.new_requests_event = None
|
||||||
|
|
||||||
|
def __contains__(self, item):
|
||||||
|
return item in self._request_streams
|
||||||
|
|
||||||
|
def init_event(self):
|
||||||
|
self.new_requests_event = asyncio.Event()
|
||||||
|
|
||||||
|
def propagate_exception(self,
|
||||||
|
exc: Exception,
|
||||||
|
request_id: Optional[str] = None) -> None:
|
||||||
|
"""Propagate an exception to request streams
|
||||||
|
(all if request_id is None)."""
|
||||||
|
if request_id is not None:
|
||||||
|
self._request_streams[request_id].put(exc)
|
||||||
|
else:
|
||||||
|
for stream in self._request_streams.values():
|
||||||
|
stream.put(exc)
|
||||||
|
|
||||||
|
def process_request_output(self,
|
||||||
|
request_output: RequestOutput,
|
||||||
|
*,
|
||||||
|
verbose: bool = False) -> None:
|
||||||
|
"""Process a request output from the engine."""
|
||||||
|
request_id = request_output.request_id
|
||||||
|
|
||||||
|
self._request_streams[request_id].put(request_output)
|
||||||
|
if request_output.finished:
|
||||||
|
if verbose:
|
||||||
|
logger.info(f"Finished request {request_id}.")
|
||||||
|
self.abort_request(request_id)
|
||||||
|
|
||||||
|
def add_request(self, request_id: str,
|
||||||
|
**engine_add_request_kwargs) -> AsyncStream:
|
||||||
|
"""Add a request to be sent to the engine on the next background
|
||||||
|
loop iteration."""
|
||||||
|
if request_id in self._request_streams:
|
||||||
|
raise KeyError(f"Request {request_id} already exists.")
|
||||||
|
|
||||||
|
stream = AsyncStream(request_id)
|
||||||
|
self._new_requests.put_nowait((stream, {
|
||||||
|
"request_id": request_id,
|
||||||
|
**engine_add_request_kwargs
|
||||||
|
}))
|
||||||
|
|
||||||
|
self.new_requests_event.set()
|
||||||
|
|
||||||
|
return stream
|
||||||
|
|
||||||
|
def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
|
||||||
|
"""Abort a request during next background loop iteration."""
|
||||||
|
if verbose:
|
||||||
|
logger.info(f"Aborted request {request_id}.")
|
||||||
|
|
||||||
|
self._finished_requests.put_nowait(request_id)
|
||||||
|
|
||||||
|
if request_id not in self._request_streams or self._request_streams[
|
||||||
|
request_id].finished:
|
||||||
|
# The request has already finished or been aborted.
|
||||||
|
return
|
||||||
|
|
||||||
|
self._request_streams[request_id].finish()
|
||||||
|
|
||||||
|
def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]:
|
||||||
|
"""Get the new requests and finished requests to be
|
||||||
|
sent to the engine."""
|
||||||
|
new_requests: List[dict] = []
|
||||||
|
finished_requests: Set[str] = set()
|
||||||
|
|
||||||
|
while not self._finished_requests.empty():
|
||||||
|
request_id = self._finished_requests.get_nowait()
|
||||||
|
finished_requests.add(request_id)
|
||||||
|
self._request_streams.pop(request_id, None)
|
||||||
|
|
||||||
|
while not self._new_requests.empty():
|
||||||
|
stream, new_request = self._new_requests.get_nowait()
|
||||||
|
if stream.request_id in finished_requests:
|
||||||
|
# The request has already been aborted.
|
||||||
|
stream.finish()
|
||||||
|
continue
|
||||||
|
self._request_streams[stream.request_id] = stream
|
||||||
|
new_requests.append(new_request)
|
||||||
|
|
||||||
|
self.new_requests_event.clear()
|
||||||
|
|
||||||
|
return new_requests, finished_requests
|
||||||
|
|
||||||
|
async def wait_for_new_requests(self):
|
||||||
|
await self.new_requests_event.wait()
|
||||||
|
|
||||||
|
|
||||||
|
class _AsyncLLMEngine(LLMEngine):
|
||||||
|
"""Extension of LLMEngine to add async methods."""
|
||||||
|
|
||||||
|
async def step_async(self) -> List[RequestOutput]:
|
||||||
|
"""Performs one decoding iteration and returns newly generated results.
|
||||||
|
The workers are ran asynchronously if possible.
|
||||||
|
|
||||||
|
This function performs one decoding iteration of the engine. It first
|
||||||
|
schedules the sequences to be executed in the next iteration and the
|
||||||
|
token blocks to be swapped in/out/copy. Then, it executes the model
|
||||||
|
and updates the scheduler with the model outputs. Finally, it decodes
|
||||||
|
the sequences and returns the newly generated results.
|
||||||
|
"""
|
||||||
|
seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
|
||||||
|
if scheduler_outputs.is_empty():
|
||||||
|
return ignored
|
||||||
|
|
||||||
|
# Execute the model.
|
||||||
|
output = await self._run_workers_async(
|
||||||
|
"execute_model",
|
||||||
|
seq_group_metadata_list=seq_group_metadata_list,
|
||||||
|
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
|
||||||
|
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
||||||
|
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._process_model_outputs(output, scheduler_outputs) + ignored
|
||||||
|
|
||||||
|
async def _run_workers_async(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
*args,
|
||||||
|
get_all_outputs: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Any:
|
||||||
|
"""Runs the given method on all workers."""
|
||||||
|
all_outputs = []
|
||||||
|
for worker in self.workers:
|
||||||
|
if self.parallel_config.worker_use_ray:
|
||||||
|
executor = partial(worker.execute_method.remote, method)
|
||||||
|
else:
|
||||||
|
executor = getattr(worker, method)
|
||||||
|
|
||||||
|
output = executor(*args, **kwargs)
|
||||||
|
all_outputs.append(output)
|
||||||
|
|
||||||
|
if self.parallel_config.worker_use_ray:
|
||||||
|
all_outputs = await asyncio.gather(*all_outputs)
|
||||||
|
|
||||||
|
if get_all_outputs:
|
||||||
|
return all_outputs
|
||||||
|
|
||||||
|
# Make sure all workers have the same results.
|
||||||
|
output = all_outputs[0]
|
||||||
|
for other_output in all_outputs[1:]:
|
||||||
|
assert output == other_output
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class AsyncLLMEngine:
|
class AsyncLLMEngine:
|
||||||
@@ -34,52 +248,149 @@ class AsyncLLMEngine:
|
|||||||
async frontend will be executed in a separate process as the
|
async frontend will be executed in a separate process as the
|
||||||
model workers.
|
model workers.
|
||||||
log_requests: Whether to log the requests.
|
log_requests: Whether to log the requests.
|
||||||
|
start_engine_loop: If True, the background task to run the engine
|
||||||
|
will be automatically started in the generate call.
|
||||||
*args, *kwargs: Arguments for LLMEngine.
|
*args, *kwargs: Arguments for LLMEngine.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
worker_use_ray: bool,
|
worker_use_ray: bool,
|
||||||
engine_use_ray: bool,
|
engine_use_ray: bool,
|
||||||
*args,
|
*args,
|
||||||
log_requests: bool = True,
|
log_requests: bool = True,
|
||||||
|
max_log_len: Optional[int] = None,
|
||||||
|
start_engine_loop: bool = True,
|
||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
self.worker_use_ray = worker_use_ray
|
self.worker_use_ray = worker_use_ray
|
||||||
self.engine_use_ray = engine_use_ray
|
self.engine_use_ray = engine_use_ray
|
||||||
self.log_requests = log_requests
|
self.log_requests = log_requests
|
||||||
if not self.engine_use_ray:
|
self.max_log_len = max_log_len
|
||||||
engine_class = LLMEngine
|
self.engine = self._init_engine(*args, **kwargs)
|
||||||
elif self.worker_use_ray:
|
|
||||||
engine_class = ray.remote(num_cpus=0)(LLMEngine).remote
|
self.background_loop = None
|
||||||
else:
|
# We need to keep a reference to unshielded
|
||||||
engine_class = ray.remote(num_gpus=1)(LLMEngine).remote
|
# task as well to prevent it from being garbage
|
||||||
self.engine = engine_class(*args, **kwargs)
|
# collected
|
||||||
# Request id -> request output.
|
self._background_loop_unshielded = None
|
||||||
self.request_outputs: Dict[str, RequestOutput] = {}
|
self.start_engine_loop = start_engine_loop
|
||||||
# Request id -> event to notify that there is new output.
|
self._request_tracker = RequestTracker()
|
||||||
self.request_events: Dict[str, asyncio.Event] = {}
|
|
||||||
self.is_engine_running = False
|
@property
|
||||||
self.kicking_request_id: Optional[str] = None
|
def is_running(self) -> bool:
|
||||||
|
return (self.background_loop is not None
|
||||||
|
and not self.background_loop.done())
|
||||||
|
|
||||||
|
def start_background_loop(self) -> None:
|
||||||
|
"""Start the background loop."""
|
||||||
|
if self.is_running:
|
||||||
|
raise RuntimeError("Background loop is already running.")
|
||||||
|
self._request_tracker.init_event()
|
||||||
|
|
||||||
|
self._background_loop_unshielded = asyncio.get_event_loop(
|
||||||
|
).create_task(self.run_engine_loop())
|
||||||
|
self._background_loop_unshielded.add_done_callback(
|
||||||
|
partial(_raise_exception_on_finish,
|
||||||
|
request_tracker=self._request_tracker))
|
||||||
|
self.background_loop = asyncio.shield(self._background_loop_unshielded)
|
||||||
|
|
||||||
|
def _init_engine(self, *args,
|
||||||
|
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
|
||||||
|
if not self.engine_use_ray:
|
||||||
|
engine_class = self._engine_class
|
||||||
|
elif self.worker_use_ray:
|
||||||
|
engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
|
||||||
|
else:
|
||||||
|
engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
|
||||||
|
return engine_class(*args, **kwargs)
|
||||||
|
|
||||||
|
async def engine_step(self) -> bool:
|
||||||
|
"""Kick the engine to process the waiting requests.
|
||||||
|
|
||||||
|
Returns True if there are in-progress requests."""
|
||||||
|
|
||||||
|
new_requests, finished_requests = (
|
||||||
|
self._request_tracker.get_new_and_finished_requests())
|
||||||
|
|
||||||
|
for new_request in new_requests:
|
||||||
|
# Add the request into the vLLM engine's waiting queue.
|
||||||
|
# TODO: Maybe add add_request_batch to reduce Ray overhead
|
||||||
|
if self.engine_use_ray:
|
||||||
|
await self.engine.add_request.remote(**new_request)
|
||||||
|
else:
|
||||||
|
self.engine.add_request(**new_request)
|
||||||
|
|
||||||
|
if finished_requests:
|
||||||
|
await self._engine_abort(finished_requests)
|
||||||
|
|
||||||
async def engine_step(self, kicking_request_id: Optional[str] = None):
|
|
||||||
"""Kick the engine to process the waiting requests."""
|
|
||||||
self.is_engine_running = True
|
|
||||||
self.kicking_request_id = kicking_request_id
|
|
||||||
if self.engine_use_ray:
|
if self.engine_use_ray:
|
||||||
request_outputs = await self.engine.step.remote()
|
request_outputs = await self.engine.step.remote()
|
||||||
else:
|
else:
|
||||||
# Yield to the event loop to allow other coroutines to run
|
request_outputs = await self.engine.step_async()
|
||||||
# while is_engine_running is True. This let the engine to add new
|
|
||||||
# requests into the queue.
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
request_outputs = self.engine.step()
|
|
||||||
self.is_engine_running = False
|
|
||||||
self.kicking_request_id = None
|
|
||||||
|
|
||||||
# Notify the waiting coroutines that there are new outputs ready.
|
# Put the outputs into the corresponding streams.
|
||||||
for request_output in request_outputs:
|
for request_output in request_outputs:
|
||||||
request_id = request_output.request_id
|
self._request_tracker.process_request_output(
|
||||||
self.request_outputs[request_id] = request_output
|
request_output, verbose=self.log_requests)
|
||||||
self.request_events[request_id].set()
|
|
||||||
|
return len(request_outputs) > 0
|
||||||
|
|
||||||
|
async def _engine_abort(self, request_ids: Iterable[str]):
|
||||||
|
if self.engine_use_ray:
|
||||||
|
await self.engine.abort_request.remote(request_ids)
|
||||||
|
else:
|
||||||
|
self.engine.abort_request(request_ids)
|
||||||
|
|
||||||
|
async def run_engine_loop(self):
|
||||||
|
# Initialize the RequestTracker here so it uses the right event loop.
|
||||||
|
has_requests_in_progress = False
|
||||||
|
while True:
|
||||||
|
if not has_requests_in_progress:
|
||||||
|
await self._request_tracker.wait_for_new_requests()
|
||||||
|
has_requests_in_progress = await self.engine_step()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
async def add_request(
|
||||||
|
self,
|
||||||
|
request_id: str,
|
||||||
|
prompt: Optional[str],
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
prompt_token_ids: Optional[List[int]] = None,
|
||||||
|
arrival_time: Optional[float] = None,
|
||||||
|
) -> AsyncStream:
|
||||||
|
if self.log_requests:
|
||||||
|
shortened_prompt = prompt
|
||||||
|
shortened_token_ids = prompt_token_ids
|
||||||
|
if self.max_log_len is not None:
|
||||||
|
if shortened_prompt is not None:
|
||||||
|
shortened_prompt = shortened_prompt[:self.max_log_len]
|
||||||
|
if shortened_token_ids is not None:
|
||||||
|
shortened_token_ids = shortened_token_ids[:self.
|
||||||
|
max_log_len]
|
||||||
|
logger.info(f"Received request {request_id}: "
|
||||||
|
f"prompt: {shortened_prompt!r}, "
|
||||||
|
f"sampling params: {sampling_params}, "
|
||||||
|
f"prompt token ids: {shortened_token_ids}.")
|
||||||
|
|
||||||
|
if not self.is_running:
|
||||||
|
if self.start_engine_loop:
|
||||||
|
self.start_background_loop()
|
||||||
|
else:
|
||||||
|
raise AsyncEngineDeadError(
|
||||||
|
"Background loop is not running. If it was running, "
|
||||||
|
"inspect the output to find the stacktrace of the "
|
||||||
|
"error that caused the background loop to stop "
|
||||||
|
"(AsyncEngineDeadError).")
|
||||||
|
|
||||||
|
stream = self._request_tracker.add_request(
|
||||||
|
request_id,
|
||||||
|
prompt=prompt,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
prompt_token_ids=prompt_token_ids,
|
||||||
|
arrival_time=arrival_time)
|
||||||
|
|
||||||
|
return stream
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
@@ -108,76 +419,20 @@ class AsyncLLMEngine:
|
|||||||
# Preprocess the request.
|
# Preprocess the request.
|
||||||
arrival_time = time.time()
|
arrival_time = time.time()
|
||||||
|
|
||||||
# Create an event to notify us that there is new output from the
|
try:
|
||||||
# vLLM engine.
|
stream = await self.add_request(request_id,
|
||||||
request_event = asyncio.Event()
|
prompt,
|
||||||
self.request_events[request_id] = request_event
|
sampling_params,
|
||||||
|
prompt_token_ids=prompt_token_ids,
|
||||||
|
arrival_time=arrival_time)
|
||||||
|
|
||||||
if self.log_requests:
|
async for request_output in stream:
|
||||||
logger.info(f"Received request {request_id}: "
|
yield request_output
|
||||||
f"prompt: {prompt!r}, "
|
except (Exception, asyncio.CancelledError) as e:
|
||||||
f"sampling params: {sampling_params}, "
|
# If there is an exception or coroutine is cancelled, abort the
|
||||||
f"prompt token ids: {prompt_token_ids}.")
|
# request.
|
||||||
|
self._abort(request_id)
|
||||||
# Add the request into the vLLM engine's waiting queue.
|
raise e
|
||||||
if self.engine_use_ray:
|
|
||||||
await self.engine.add_request.remote(
|
|
||||||
request_id,
|
|
||||||
prompt,
|
|
||||||
sampling_params,
|
|
||||||
prompt_token_ids=prompt_token_ids,
|
|
||||||
arrival_time=arrival_time)
|
|
||||||
else:
|
|
||||||
self.engine.add_request(request_id,
|
|
||||||
prompt,
|
|
||||||
sampling_params,
|
|
||||||
prompt_token_ids=prompt_token_ids,
|
|
||||||
arrival_time=arrival_time)
|
|
||||||
|
|
||||||
# The vLLM engine does not have a background loop that keeps
|
|
||||||
# processing incoming requests. Therefore, we need to keep kicking
|
|
||||||
# the engine to process the requests.
|
|
||||||
while True:
|
|
||||||
if request_id not in self.request_events:
|
|
||||||
# The request has been aborted.
|
|
||||||
return
|
|
||||||
|
|
||||||
# Kick the engine if the engine is not running.
|
|
||||||
if not self.is_engine_running:
|
|
||||||
try:
|
|
||||||
await self.engine_step(request_id)
|
|
||||||
except RuntimeError as e:
|
|
||||||
await self.abort(request_id)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
# Wait for new output. The group_event will be set in engine_step
|
|
||||||
# when there is new output available for the sequence group.
|
|
||||||
# Added a timeout to prevent deadlock.
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(request_event.wait(),
|
|
||||||
timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
continue
|
|
||||||
# Reset the event to wait for the next output.
|
|
||||||
request_event.clear()
|
|
||||||
|
|
||||||
# Decode and return new outputs.
|
|
||||||
request_output = self.request_outputs[request_id]
|
|
||||||
yield request_output
|
|
||||||
|
|
||||||
# Once finished, release the resources of the sequence group.
|
|
||||||
if request_output.finished:
|
|
||||||
if self.log_requests:
|
|
||||||
logger.info(f"Finished request {request_id}.")
|
|
||||||
|
|
||||||
del self.request_outputs[request_id]
|
|
||||||
del self.request_events[request_id]
|
|
||||||
# Kick the engine if the engine is not running. This is to
|
|
||||||
# prevent that there are still requests in engine's waiting
|
|
||||||
# queue to be executed.
|
|
||||||
if not self.is_engine_running:
|
|
||||||
await self.engine_step()
|
|
||||||
break
|
|
||||||
|
|
||||||
async def abort(self, request_id: str) -> None:
|
async def abort(self, request_id: str) -> None:
|
||||||
"""Abort a request.
|
"""Abort a request.
|
||||||
@@ -188,28 +443,26 @@ class AsyncLLMEngine:
|
|||||||
Args:
|
Args:
|
||||||
request_id: The unique id of the request.
|
request_id: The unique id of the request.
|
||||||
"""
|
"""
|
||||||
if request_id not in self.request_events:
|
if not self.is_running:
|
||||||
# The request has already finished or been aborted.
|
raise AsyncEngineDeadError(
|
||||||
return
|
"Background loop is not running. If it was running, "
|
||||||
|
"inspect the output to find the stacktrace of the "
|
||||||
|
"error that caused the background loop to stop "
|
||||||
|
"(AsyncEngineDeadError).")
|
||||||
|
|
||||||
if self.log_requests:
|
return self._abort(request_id)
|
||||||
logger.info(f"Aborted request {request_id}.")
|
|
||||||
|
|
||||||
if self.engine_use_ray:
|
def _abort(self, request_id: str) -> None:
|
||||||
await self.engine.abort_request.remote(request_id)
|
"""Abort a request.
|
||||||
else:
|
|
||||||
self.engine.abort_request(request_id)
|
|
||||||
|
|
||||||
if request_id in self.request_events:
|
Abort a submitted request. If the request is finished or not found,
|
||||||
del self.request_events[request_id]
|
this method will be a no-op.
|
||||||
if request_id in self.request_outputs:
|
|
||||||
del self.request_outputs[request_id]
|
|
||||||
|
|
||||||
# To prevent deadlock when a request is aborted while the engine is
|
Args:
|
||||||
# running.
|
request_id: The unique id of the request.
|
||||||
if self.kicking_request_id == request_id:
|
"""
|
||||||
self.is_engine_running = False
|
self._request_tracker.abort_request(request_id,
|
||||||
self.kicking_request_id = None
|
verbose=self.log_requests)
|
||||||
|
|
||||||
async def get_model_config(self) -> ModelConfig:
|
async def get_model_config(self) -> ModelConfig:
|
||||||
"""Get the model configuration of the vLLM engine."""
|
"""Get the model configuration of the vLLM engine."""
|
||||||
@@ -220,20 +473,23 @@ class AsyncLLMEngine:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_engine_args(cls,
|
def from_engine_args(cls,
|
||||||
engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
|
engine_args: AsyncEngineArgs,
|
||||||
|
start_engine_loop: bool = True) -> "AsyncLLMEngine":
|
||||||
"""Creates an async LLM engine from the engine arguments."""
|
"""Creates an async LLM engine from the engine arguments."""
|
||||||
# Create the engine configs.
|
# Create the engine configs.
|
||||||
engine_configs = engine_args.create_engine_configs()
|
engine_configs = engine_args.create_engine_configs()
|
||||||
parallel_config = engine_configs[2]
|
parallel_config = engine_configs[2]
|
||||||
# Initialize the cluster.
|
# Initialize the cluster.
|
||||||
distributed_init_method, devices = initialize_cluster(
|
distributed_init_method, placement_group = initialize_cluster(
|
||||||
parallel_config, engine_args.engine_use_ray)
|
parallel_config, engine_args.engine_use_ray)
|
||||||
# Create the async LLM engine.
|
# Create the async LLM engine.
|
||||||
engine = cls(engine_args.worker_use_ray,
|
engine = cls(engine_args.worker_use_ray,
|
||||||
engine_args.engine_use_ray,
|
engine_args.engine_use_ray,
|
||||||
*engine_configs,
|
*engine_configs,
|
||||||
distributed_init_method,
|
distributed_init_method,
|
||||||
devices,
|
placement_group,
|
||||||
log_requests=not engine_args.disable_log_requests,
|
log_requests=not engine_args.disable_log_requests,
|
||||||
log_stats=not engine_args.disable_log_stats)
|
log_stats=not engine_args.disable_log_stats,
|
||||||
|
max_log_len=engine_args.max_log_len,
|
||||||
|
start_engine_loop=start_engine_loop)
|
||||||
return engine
|
return engine
|
||||||
|
|||||||
@@ -1,22 +1,34 @@
|
|||||||
|
import copy
|
||||||
import time
|
import time
|
||||||
from typing import Any, List, Optional
|
from functools import partial
|
||||||
|
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||||
SchedulerConfig)
|
SchedulerConfig)
|
||||||
from vllm.core.scheduler import Scheduler
|
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.engine.ray_utils import DeviceID, initialize_cluster, ray
|
from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
|
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
|
||||||
|
SequenceGroupMetadata, SequenceOutputs,
|
||||||
|
SequenceStatus)
|
||||||
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
||||||
get_tokenizer)
|
get_tokenizer)
|
||||||
from vllm.utils import Counter
|
from vllm.utils import Counter
|
||||||
from vllm.worker.worker import Worker
|
|
||||||
|
if ray:
|
||||||
|
from ray.air.util.torch_dist import init_torch_dist_process_group
|
||||||
|
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ray.util.placement_group import PlacementGroup
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
_LOGGING_INTERVAL_SEC = 5
|
||||||
|
|
||||||
|
|
||||||
class LLMEngine:
|
class LLMEngine:
|
||||||
"""An LLM engine that receives requests and generates texts.
|
"""An LLM engine that receives requests and generates texts.
|
||||||
@@ -42,8 +54,8 @@ class LLMEngine:
|
|||||||
scheduler_config: The configuration related to the request scheduler.
|
scheduler_config: The configuration related to the request scheduler.
|
||||||
distributed_init_method: The initialization method for distributed
|
distributed_init_method: The initialization method for distributed
|
||||||
execution. See `torch.distributed.init_process_group` for details.
|
execution. See `torch.distributed.init_process_group` for details.
|
||||||
stage_devices: The list of devices for each stage. Each stage is a list
|
placement_group: Ray placement group for distributed execution.
|
||||||
of (rank, node_resource, device) tuples.
|
Required for distributed execution.
|
||||||
log_stats: Whether to log statistics.
|
log_stats: Whether to log statistics.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -54,7 +66,7 @@ class LLMEngine:
|
|||||||
parallel_config: ParallelConfig,
|
parallel_config: ParallelConfig,
|
||||||
scheduler_config: SchedulerConfig,
|
scheduler_config: SchedulerConfig,
|
||||||
distributed_init_method: str,
|
distributed_init_method: str,
|
||||||
stage_devices: List[List[DeviceID]],
|
placement_group: Optional["PlacementGroup"],
|
||||||
log_stats: bool,
|
log_stats: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -62,50 +74,112 @@ class LLMEngine:
|
|||||||
f"model={model_config.model!r}, "
|
f"model={model_config.model!r}, "
|
||||||
f"tokenizer={model_config.tokenizer!r}, "
|
f"tokenizer={model_config.tokenizer!r}, "
|
||||||
f"tokenizer_mode={model_config.tokenizer_mode}, "
|
f"tokenizer_mode={model_config.tokenizer_mode}, "
|
||||||
|
f"revision={model_config.revision}, "
|
||||||
|
f"trust_remote_code={model_config.trust_remote_code}, "
|
||||||
f"dtype={model_config.dtype}, "
|
f"dtype={model_config.dtype}, "
|
||||||
f"use_dummy_weights={model_config.use_dummy_weights}, "
|
f"max_seq_len={model_config.max_model_len}, "
|
||||||
f"download_dir={model_config.download_dir!r}, "
|
f"download_dir={model_config.download_dir!r}, "
|
||||||
f"use_np_weights={model_config.use_np_weights}, "
|
f"load_format={model_config.load_format}, "
|
||||||
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
||||||
|
f"quantization={model_config.quantization}, "
|
||||||
f"seed={model_config.seed})")
|
f"seed={model_config.seed})")
|
||||||
# TODO(woosuk): Print more configs in debug mode.
|
# TODO(woosuk): Print more configs in debug mode.
|
||||||
|
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.cache_config = cache_config
|
self.cache_config = cache_config
|
||||||
|
assert self.cache_config.sliding_window == getattr(
|
||||||
|
self.model_config.hf_config, "sliding_window", None)
|
||||||
self.parallel_config = parallel_config
|
self.parallel_config = parallel_config
|
||||||
self.scheduler_config = scheduler_config
|
self.scheduler_config = scheduler_config
|
||||||
self.log_stats = log_stats
|
self.log_stats = log_stats
|
||||||
self._verify_args()
|
self._verify_args()
|
||||||
|
|
||||||
self.tokenizer = get_tokenizer(
|
self.tokenizer = get_tokenizer(
|
||||||
model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
|
model_config.tokenizer,
|
||||||
|
tokenizer_mode=model_config.tokenizer_mode,
|
||||||
|
trust_remote_code=model_config.trust_remote_code,
|
||||||
|
revision=model_config.revision)
|
||||||
self.seq_counter = Counter()
|
self.seq_counter = Counter()
|
||||||
|
|
||||||
# Create the parallel GPU workers.
|
# Create the parallel GPU workers.
|
||||||
self.workers: List[Worker] = []
|
if self.parallel_config.worker_use_ray:
|
||||||
assert len(stage_devices) == 1, "Only support one stage for now."
|
self._init_workers_ray(placement_group)
|
||||||
for rank, node_resource, _ in stage_devices[0]:
|
else:
|
||||||
worker_cls = Worker
|
self._init_workers(distributed_init_method)
|
||||||
if self.parallel_config.worker_use_ray:
|
|
||||||
worker_cls = ray.remote(
|
|
||||||
num_cpus=0,
|
|
||||||
num_gpus=1,
|
|
||||||
resources={node_resource: 1e-3},
|
|
||||||
)(worker_cls).remote
|
|
||||||
|
|
||||||
worker = worker_cls(
|
|
||||||
model_config,
|
|
||||||
parallel_config,
|
|
||||||
scheduler_config,
|
|
||||||
rank,
|
|
||||||
distributed_init_method,
|
|
||||||
)
|
|
||||||
self.workers.append(worker)
|
|
||||||
# Profile the memory usage and initialize the cache.
|
# Profile the memory usage and initialize the cache.
|
||||||
self._init_cache()
|
self._init_cache()
|
||||||
|
|
||||||
# Create the scheduler.
|
# Create the scheduler.
|
||||||
self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)
|
self.scheduler = Scheduler(scheduler_config, cache_config)
|
||||||
|
|
||||||
|
# Logging.
|
||||||
|
self.last_logging_time = 0.0
|
||||||
|
# List of (timestamp, num_tokens)
|
||||||
|
self.num_prompt_tokens: List[Tuple[float, int]] = []
|
||||||
|
# List of (timestamp, num_tokens)
|
||||||
|
self.num_generation_tokens: List[Tuple[float, int]] = []
|
||||||
|
|
||||||
|
def _init_workers(self, distributed_init_method: str):
|
||||||
|
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||||
|
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||||
|
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
assert self.parallel_config.world_size == 1, (
|
||||||
|
"Ray is required if parallel_config.world_size > 1.")
|
||||||
|
|
||||||
|
self.workers: List[Worker] = []
|
||||||
|
worker = Worker(
|
||||||
|
self.model_config,
|
||||||
|
self.parallel_config,
|
||||||
|
self.scheduler_config,
|
||||||
|
0,
|
||||||
|
distributed_init_method,
|
||||||
|
)
|
||||||
|
self.workers.append(worker)
|
||||||
|
self._run_workers(
|
||||||
|
"init_model",
|
||||||
|
get_all_outputs=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||||
|
**ray_remote_kwargs):
|
||||||
|
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||||
|
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||||
|
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
self.workers: List[Worker] = []
|
||||||
|
for bundle in placement_group.bundle_specs:
|
||||||
|
if not bundle.get("GPU", 0):
|
||||||
|
continue
|
||||||
|
worker = ray.remote(
|
||||||
|
num_cpus=0,
|
||||||
|
num_gpus=1,
|
||||||
|
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||||
|
placement_group=placement_group,
|
||||||
|
placement_group_capture_child_tasks=True),
|
||||||
|
**ray_remote_kwargs,
|
||||||
|
)(RayWorker).remote(self.model_config.trust_remote_code)
|
||||||
|
self.workers.append(worker)
|
||||||
|
|
||||||
|
# Initialize torch distributed process group for the workers.
|
||||||
|
init_torch_dist_process_group(self.workers, backend="nccl")
|
||||||
|
model_config = copy.deepcopy(self.model_config)
|
||||||
|
parallel_config = copy.deepcopy(self.parallel_config)
|
||||||
|
scheduler_config = copy.deepcopy(self.scheduler_config)
|
||||||
|
self._run_workers("init_worker",
|
||||||
|
get_all_outputs=True,
|
||||||
|
worker_init_fn=lambda: Worker(
|
||||||
|
model_config,
|
||||||
|
parallel_config,
|
||||||
|
scheduler_config,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
))
|
||||||
|
self._run_workers(
|
||||||
|
"init_model",
|
||||||
|
get_all_outputs=True,
|
||||||
|
)
|
||||||
|
|
||||||
def _verify_args(self) -> None:
|
def _verify_args(self) -> None:
|
||||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||||
@@ -149,11 +223,12 @@ class LLMEngine:
|
|||||||
engine_configs = engine_args.create_engine_configs()
|
engine_configs = engine_args.create_engine_configs()
|
||||||
parallel_config = engine_configs[2]
|
parallel_config = engine_configs[2]
|
||||||
# Initialize the cluster.
|
# Initialize the cluster.
|
||||||
distributed_init_method, devices = initialize_cluster(parallel_config)
|
distributed_init_method, placement_group = initialize_cluster(
|
||||||
|
parallel_config)
|
||||||
# Create the LLM engine.
|
# Create the LLM engine.
|
||||||
engine = cls(*engine_configs,
|
engine = cls(*engine_configs,
|
||||||
distributed_init_method,
|
distributed_init_method,
|
||||||
devices,
|
placement_group,
|
||||||
log_stats=not engine_args.disable_log_stats)
|
log_stats=not engine_args.disable_log_stats)
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
@@ -189,24 +264,21 @@ class LLMEngine:
|
|||||||
|
|
||||||
# Create the sequences.
|
# Create the sequences.
|
||||||
block_size = self.cache_config.block_size
|
block_size = self.cache_config.block_size
|
||||||
seqs: List[Sequence] = []
|
seq_id = next(self.seq_counter)
|
||||||
for _ in range(sampling_params.best_of):
|
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
|
||||||
seq_id = next(self.seq_counter)
|
|
||||||
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
|
|
||||||
seqs.append(seq)
|
|
||||||
|
|
||||||
# Create the sequence group.
|
# Create the sequence group.
|
||||||
seq_group = SequenceGroup(request_id, seqs, sampling_params,
|
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
||||||
arrival_time)
|
arrival_time)
|
||||||
|
|
||||||
# Add the sequence group to the scheduler.
|
# Add the sequence group to the scheduler.
|
||||||
self.scheduler.add_seq_group(seq_group)
|
self.scheduler.add_seq_group(seq_group)
|
||||||
|
|
||||||
def abort_request(self, request_id: str) -> None:
|
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||||
"""Aborts a request with the given ID.
|
"""Aborts a request(s) with the given ID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request_id: The ID of the request to abort.
|
request_id: The ID(s) of the request to abort.
|
||||||
"""
|
"""
|
||||||
self.scheduler.abort_seq_group(request_id)
|
self.scheduler.abort_seq_group(request_id)
|
||||||
|
|
||||||
@@ -222,6 +294,249 @@ class LLMEngine:
|
|||||||
"""Returns True if there are unfinished requests."""
|
"""Returns True if there are unfinished requests."""
|
||||||
return self.scheduler.has_unfinished_seqs()
|
return self.scheduler.has_unfinished_seqs()
|
||||||
|
|
||||||
|
def _schedule(
|
||||||
|
self
|
||||||
|
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
|
||||||
|
List[RequestOutput]]:
|
||||||
|
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||||
|
return seq_group_metadata_list, scheduler_outputs, [
|
||||||
|
RequestOutput.from_seq_group(seq_group)
|
||||||
|
for seq_group in scheduler_outputs.ignored_seq_groups
|
||||||
|
]
|
||||||
|
|
||||||
|
def _check_beam_search_early_stopping(
|
||||||
|
self,
|
||||||
|
early_stopping: Union[bool, str],
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
best_running_seq: Sequence,
|
||||||
|
current_worst_seq: Sequence,
|
||||||
|
) -> bool:
|
||||||
|
assert sampling_params.use_beam_search
|
||||||
|
length_penalty = sampling_params.length_penalty
|
||||||
|
if early_stopping is True:
|
||||||
|
return True
|
||||||
|
|
||||||
|
current_worst_score = (current_worst_seq.get_beam_search_score(
|
||||||
|
length_penalty=length_penalty,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id))
|
||||||
|
if early_stopping is False:
|
||||||
|
highest_attainable_score = (best_running_seq.get_beam_search_score(
|
||||||
|
length_penalty=length_penalty,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id))
|
||||||
|
else:
|
||||||
|
assert early_stopping == "never"
|
||||||
|
if length_penalty > 0.0:
|
||||||
|
# If length_penalty > 0.0, beam search will prefer longer
|
||||||
|
# sequences. The highest attainable score calculation is
|
||||||
|
# based on the longest possible sequence length in this case.
|
||||||
|
max_possible_length = max(
|
||||||
|
best_running_seq.get_prompt_len() +
|
||||||
|
sampling_params.max_tokens,
|
||||||
|
self.scheduler_config.max_model_len)
|
||||||
|
highest_attainable_score = (
|
||||||
|
best_running_seq.get_beam_search_score(
|
||||||
|
length_penalty=length_penalty,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id,
|
||||||
|
seq_len=max_possible_length))
|
||||||
|
else:
|
||||||
|
# Otherwise, beam search will prefer shorter sequences. The
|
||||||
|
# highest attainable score calculation is based on the current
|
||||||
|
# sequence length.
|
||||||
|
highest_attainable_score = (
|
||||||
|
best_running_seq.get_beam_search_score(
|
||||||
|
length_penalty=length_penalty,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id))
|
||||||
|
return current_worst_score >= highest_attainable_score
|
||||||
|
|
||||||
|
def _process_sequence_group_samples(
|
||||||
|
self, seq_group: SequenceGroup,
|
||||||
|
samples: List[SequenceOutputs]) -> None:
|
||||||
|
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
||||||
|
existing_finished_seqs = seq_group.get_finished_seqs()
|
||||||
|
parent_child_dict = {
|
||||||
|
parent_seq.seq_id: []
|
||||||
|
for parent_seq in parent_seqs
|
||||||
|
}
|
||||||
|
for sample in samples:
|
||||||
|
parent_child_dict[sample.parent_seq_id].append(sample)
|
||||||
|
# List of (child, parent)
|
||||||
|
child_seqs: List[Tuple[Sequence, Sequence]] = []
|
||||||
|
|
||||||
|
# Process the child samples for each parent sequence
|
||||||
|
for parent in parent_seqs:
|
||||||
|
child_samples: List[SequenceOutputs] = parent_child_dict[
|
||||||
|
parent.seq_id]
|
||||||
|
if len(child_samples) == 0:
|
||||||
|
# This parent sequence has no children samples. Remove
|
||||||
|
# the parent sequence from the sequence group since it will
|
||||||
|
# not be used in the future iterations.
|
||||||
|
parent.status = SequenceStatus.FINISHED_ABORTED
|
||||||
|
seq_group.remove(parent.seq_id)
|
||||||
|
self.scheduler.free_seq(parent)
|
||||||
|
continue
|
||||||
|
# Fork the parent sequence if there are multiple child samples.
|
||||||
|
for child_sample in child_samples[:-1]:
|
||||||
|
new_child_seq_id = next(self.seq_counter)
|
||||||
|
child = parent.fork(new_child_seq_id)
|
||||||
|
child.append_token_id(child_sample.output_token,
|
||||||
|
child_sample.logprobs)
|
||||||
|
child_seqs.append((child, parent))
|
||||||
|
# Continue the parent sequence for the last child sample.
|
||||||
|
# We reuse the parent sequence here to reduce redundant memory
|
||||||
|
# copies, especially when using non-beam search sampling methods.
|
||||||
|
last_child_sample = child_samples[-1]
|
||||||
|
parent.append_token_id(last_child_sample.output_token,
|
||||||
|
last_child_sample.logprobs)
|
||||||
|
child_seqs.append((parent, parent))
|
||||||
|
|
||||||
|
for seq, _ in child_seqs:
|
||||||
|
self._decode_sequence(seq, seq_group.sampling_params)
|
||||||
|
self._check_stop(seq, seq_group.sampling_params)
|
||||||
|
|
||||||
|
# Non-beam search case
|
||||||
|
if not seq_group.sampling_params.use_beam_search:
|
||||||
|
# For newly created child sequences, add them to the sequence group
|
||||||
|
# and fork them in block manager if they are not finished.
|
||||||
|
for seq, parent in child_seqs:
|
||||||
|
if seq is not parent:
|
||||||
|
seq_group.add(seq)
|
||||||
|
if not seq.is_finished():
|
||||||
|
self.scheduler.fork_seq(parent, seq)
|
||||||
|
|
||||||
|
# Free the finished and selected parent sequences' memory in block
|
||||||
|
# manager. Keep them in the sequence group as candidate output.
|
||||||
|
# NOTE: we need to fork the new sequences before freeing the
|
||||||
|
# old sequences.
|
||||||
|
for seq, parent in child_seqs:
|
||||||
|
if seq is parent and seq.is_finished():
|
||||||
|
self.scheduler.free_seq(seq)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Beam search case
|
||||||
|
# Select the child sequences to keep in the sequence group.
|
||||||
|
selected_child_seqs = []
|
||||||
|
unselected_child_seqs = []
|
||||||
|
beam_width = seq_group.sampling_params.best_of
|
||||||
|
length_penalty = seq_group.sampling_params.length_penalty
|
||||||
|
|
||||||
|
# Select the newly finished sequences with the highest scores
|
||||||
|
# to replace existing finished sequences.
|
||||||
|
# Tuple of (seq, parent, is_new)
|
||||||
|
existing_finished_seqs = [(seq, None, False)
|
||||||
|
for seq in existing_finished_seqs]
|
||||||
|
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
|
||||||
|
if seq.is_finished()]
|
||||||
|
all_finished_seqs = existing_finished_seqs + new_finished_seqs
|
||||||
|
# Sort the finished sequences by their scores.
|
||||||
|
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
|
||||||
|
length_penalty=length_penalty,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id),
|
||||||
|
reverse=True)
|
||||||
|
for seq, parent, is_new in all_finished_seqs[:beam_width]:
|
||||||
|
if is_new:
|
||||||
|
# A newly generated child sequence finishes and has a high
|
||||||
|
# score, so we will add it into the sequence group.
|
||||||
|
selected_child_seqs.append((seq, parent))
|
||||||
|
for seq, parent, is_new in all_finished_seqs[beam_width:]:
|
||||||
|
if is_new:
|
||||||
|
# A newly generated child sequence finishes but has a low
|
||||||
|
# score, so we will not add it into the sequence group.
|
||||||
|
# Additionally, if this sequence is a continuation of a
|
||||||
|
# parent sequence, we will need remove the parent sequence
|
||||||
|
# from the sequence group.
|
||||||
|
unselected_child_seqs.append((seq, parent))
|
||||||
|
else:
|
||||||
|
# An existing finished sequence has a low score, so we will
|
||||||
|
# remove it from the sequence group.
|
||||||
|
seq_group.remove(seq.seq_id)
|
||||||
|
|
||||||
|
# select the top beam_width sequences from the running
|
||||||
|
# sequences for the next iteration to continue the beam
|
||||||
|
# search.
|
||||||
|
running_child_seqs = [(seq, parent) for seq, parent in child_seqs
|
||||||
|
if not seq.is_finished()]
|
||||||
|
# Sort the running sequences by their scores.
|
||||||
|
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
|
||||||
|
length_penalty=length_penalty,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id),
|
||||||
|
reverse=True)
|
||||||
|
|
||||||
|
# Check if we can stop the beam search.
|
||||||
|
if len(running_child_seqs) == 0:
|
||||||
|
# No running sequences, stop the beam search.
|
||||||
|
stop_beam_search = True
|
||||||
|
elif len(all_finished_seqs) < beam_width:
|
||||||
|
# Not enough finished sequences, continue the beam search.
|
||||||
|
stop_beam_search = False
|
||||||
|
else:
|
||||||
|
# Check the early stopping criteria
|
||||||
|
best_running_seq = running_child_seqs[0][0]
|
||||||
|
current_worst_seq = all_finished_seqs[beam_width - 1][0]
|
||||||
|
stop_beam_search = self._check_beam_search_early_stopping(
|
||||||
|
seq_group.sampling_params.early_stopping,
|
||||||
|
seq_group.sampling_params, best_running_seq, current_worst_seq)
|
||||||
|
|
||||||
|
if stop_beam_search:
|
||||||
|
# Stop the beam search and remove all the running sequences from
|
||||||
|
# the sequence group.
|
||||||
|
unselected_child_seqs.extend(running_child_seqs)
|
||||||
|
else:
|
||||||
|
# Continue the beam search and select the top beam_width sequences
|
||||||
|
# to continue the beam search.
|
||||||
|
selected_child_seqs.extend(running_child_seqs[:beam_width])
|
||||||
|
# The remaining running sequences will not be used in the next
|
||||||
|
# iteration. Again, if these sequences are continuations of
|
||||||
|
# parent sequences, we will need to remove the parent sequences
|
||||||
|
# from the sequence group.
|
||||||
|
unselected_child_seqs.extend(running_child_seqs[beam_width:])
|
||||||
|
|
||||||
|
# For newly created child sequences, add them to the sequence group
|
||||||
|
# and fork them in block manager if they are not finished.
|
||||||
|
for seq, parent in selected_child_seqs:
|
||||||
|
if seq is not parent:
|
||||||
|
seq_group.add(seq)
|
||||||
|
if not seq.is_finished():
|
||||||
|
self.scheduler.fork_seq(parent, seq)
|
||||||
|
|
||||||
|
# Free the finished and selected parent sequences' memory in block
|
||||||
|
# manager. Keep them in the sequence group as candidate output.
|
||||||
|
for seq, parent in selected_child_seqs:
|
||||||
|
if seq is parent and seq.is_finished():
|
||||||
|
self.scheduler.free_seq(seq)
|
||||||
|
|
||||||
|
# Remove the unselected parent sequences from the sequence group and
|
||||||
|
# free their memory in block manager.
|
||||||
|
for seq, parent in unselected_child_seqs:
|
||||||
|
if seq is parent:
|
||||||
|
# Remove the parent sequence if it is not selected for next
|
||||||
|
# iteration
|
||||||
|
seq_group.remove(seq.seq_id)
|
||||||
|
self.scheduler.free_seq(seq)
|
||||||
|
|
||||||
|
def _process_model_outputs(
|
||||||
|
self, output: SamplerOutput,
|
||||||
|
scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
|
||||||
|
# Update the scheduled sequence groups with the model outputs.
|
||||||
|
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
|
||||||
|
for seq_group, samples in zip(scheduled_seq_groups, output):
|
||||||
|
self._process_sequence_group_samples(seq_group, samples)
|
||||||
|
|
||||||
|
# Free the finished sequence groups.
|
||||||
|
self.scheduler.free_finished_seq_groups()
|
||||||
|
|
||||||
|
# Create the outputs.
|
||||||
|
request_outputs: List[RequestOutput] = []
|
||||||
|
for seq_group in (scheduled_seq_groups +
|
||||||
|
scheduler_outputs.ignored_seq_groups):
|
||||||
|
request_output = RequestOutput.from_seq_group(seq_group)
|
||||||
|
request_outputs.append(request_output)
|
||||||
|
|
||||||
|
if self.log_stats:
|
||||||
|
# Log the system stats.
|
||||||
|
self._log_system_stats(scheduler_outputs.prompt_run,
|
||||||
|
scheduler_outputs.num_batched_tokens)
|
||||||
|
return request_outputs
|
||||||
|
|
||||||
def step(self) -> List[RequestOutput]:
|
def step(self) -> List[RequestOutput]:
|
||||||
"""Performs one decoding iteration and returns newly generated results.
|
"""Performs one decoding iteration and returns newly generated results.
|
||||||
|
|
||||||
@@ -231,12 +546,9 @@ class LLMEngine:
|
|||||||
and updates the scheduler with the model outputs. Finally, it decodes
|
and updates the scheduler with the model outputs. Finally, it decodes
|
||||||
the sequences and returns the newly generated results.
|
the sequences and returns the newly generated results.
|
||||||
"""
|
"""
|
||||||
(seq_group_metadata_list, scheduler_outputs,
|
seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
|
||||||
ignored_seq_groups) = self.scheduler.schedule()
|
if scheduler_outputs.is_empty():
|
||||||
if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
|
return ignored
|
||||||
and (not ignored_seq_groups)):
|
|
||||||
# Nothing to do.
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Execute the model.
|
# Execute the model.
|
||||||
output = self._run_workers(
|
output = self._run_workers(
|
||||||
@@ -246,72 +558,121 @@ class LLMEngine:
|
|||||||
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
||||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||||
)
|
)
|
||||||
# Update the scheduler with the model outputs.
|
|
||||||
seq_groups = self.scheduler.update(output)
|
|
||||||
|
|
||||||
# Decode the sequences.
|
return self._process_model_outputs(output, scheduler_outputs) + ignored
|
||||||
self._decode_sequences(seq_groups)
|
|
||||||
# Stop the sequences that meet the stopping criteria.
|
|
||||||
self._stop_sequences(seq_groups)
|
|
||||||
# Free the finished sequence groups.
|
|
||||||
self.scheduler.free_finished_seq_groups()
|
|
||||||
|
|
||||||
# Create the outputs.
|
def _log_system_stats(
|
||||||
request_outputs: List[RequestOutput] = []
|
self,
|
||||||
for seq_group in seq_groups + ignored_seq_groups:
|
prompt_run: bool,
|
||||||
request_output = RequestOutput.from_seq_group(seq_group)
|
num_batched_tokens: int,
|
||||||
request_outputs.append(request_output)
|
) -> None:
|
||||||
return request_outputs
|
now = time.time()
|
||||||
|
# Log the number of batched input tokens.
|
||||||
|
if prompt_run:
|
||||||
|
self.num_prompt_tokens.append((now, num_batched_tokens))
|
||||||
|
else:
|
||||||
|
self.num_generation_tokens.append((now, num_batched_tokens))
|
||||||
|
|
||||||
def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
|
elapsed_time = now - self.last_logging_time
|
||||||
"""Decodes the sequence outputs."""
|
if elapsed_time < _LOGGING_INTERVAL_SEC:
|
||||||
for seq_group in seq_groups:
|
return
|
||||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
|
||||||
new_token, new_output_text = detokenize_incrementally(
|
|
||||||
self.tokenizer,
|
|
||||||
seq.output_tokens,
|
|
||||||
seq.get_last_token_id(),
|
|
||||||
skip_special_tokens=True,
|
|
||||||
)
|
|
||||||
seq.output_tokens.append(new_token)
|
|
||||||
seq.output_text = new_output_text
|
|
||||||
|
|
||||||
def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
|
# Discard the old stats.
|
||||||
|
self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens
|
||||||
|
if now - t < _LOGGING_INTERVAL_SEC]
|
||||||
|
self.num_generation_tokens = [(t, n)
|
||||||
|
for t, n in self.num_generation_tokens
|
||||||
|
if now - t < _LOGGING_INTERVAL_SEC]
|
||||||
|
|
||||||
|
if len(self.num_prompt_tokens) > 1:
|
||||||
|
total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
|
||||||
|
window = now - self.num_prompt_tokens[0][0]
|
||||||
|
avg_prompt_throughput = total_num_tokens / window
|
||||||
|
else:
|
||||||
|
avg_prompt_throughput = 0.0
|
||||||
|
if len(self.num_generation_tokens) > 1:
|
||||||
|
total_num_tokens = sum(n
|
||||||
|
for _, n in self.num_generation_tokens[:-1])
|
||||||
|
window = now - self.num_generation_tokens[0][0]
|
||||||
|
avg_generation_throughput = total_num_tokens / window
|
||||||
|
else:
|
||||||
|
avg_generation_throughput = 0.0
|
||||||
|
|
||||||
|
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
|
||||||
|
num_free_gpu_blocks = (
|
||||||
|
self.scheduler.block_manager.get_num_free_gpu_blocks())
|
||||||
|
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
|
||||||
|
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
|
||||||
|
|
||||||
|
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
|
||||||
|
if total_num_cpu_blocks > 0:
|
||||||
|
num_free_cpu_blocks = (
|
||||||
|
self.scheduler.block_manager.get_num_free_cpu_blocks())
|
||||||
|
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
|
||||||
|
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
|
||||||
|
else:
|
||||||
|
cpu_cache_usage = 0.0
|
||||||
|
|
||||||
|
logger.info("Avg prompt throughput: "
|
||||||
|
f"{avg_prompt_throughput:.1f} tokens/s, "
|
||||||
|
"Avg generation throughput: "
|
||||||
|
f"{avg_generation_throughput:.1f} tokens/s, "
|
||||||
|
f"Running: {len(self.scheduler.running)} reqs, "
|
||||||
|
f"Swapped: {len(self.scheduler.swapped)} reqs, "
|
||||||
|
f"Pending: {len(self.scheduler.waiting)} reqs, "
|
||||||
|
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
|
||||||
|
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
|
||||||
|
self.last_logging_time = now
|
||||||
|
|
||||||
|
def _decode_sequence(self, seq: Sequence,
|
||||||
|
sampling_params: SamplingParams) -> None:
|
||||||
|
"""Decodes the new token for a sequence."""
|
||||||
|
(new_tokens, new_output_text, prefix_offset,
|
||||||
|
read_offset) = detokenize_incrementally(
|
||||||
|
self.tokenizer,
|
||||||
|
all_input_ids=seq.get_token_ids(),
|
||||||
|
prev_tokens=seq.tokens,
|
||||||
|
prefix_offset=seq.prefix_offset,
|
||||||
|
read_offset=seq.read_offset,
|
||||||
|
skip_special_tokens=sampling_params.skip_special_tokens,
|
||||||
|
)
|
||||||
|
if seq.tokens is None:
|
||||||
|
seq.tokens = new_tokens
|
||||||
|
else:
|
||||||
|
seq.tokens.extend(new_tokens)
|
||||||
|
seq.prefix_offset = prefix_offset
|
||||||
|
seq.read_offset = read_offset
|
||||||
|
seq.output_text += new_output_text
|
||||||
|
|
||||||
|
def _check_stop(self, seq: Sequence,
|
||||||
|
sampling_params: SamplingParams) -> None:
|
||||||
"""Stop the finished sequences."""
|
"""Stop the finished sequences."""
|
||||||
for seq_group in seq_groups:
|
for stop_str in sampling_params.stop:
|
||||||
sampling_params = seq_group.sampling_params
|
if seq.output_text.endswith(stop_str):
|
||||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
# Truncate the output text so that the stop string is
|
||||||
# Check if the sequence has generated a stop string.
|
# not included in the output.
|
||||||
stopped = False
|
seq.output_text = seq.output_text[:-len(stop_str)]
|
||||||
for stop_str in sampling_params.stop:
|
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||||
if seq.output_text.endswith(stop_str):
|
return
|
||||||
# Truncate the output text so that the stop string is
|
if seq.get_last_token_id() in sampling_params.stop_token_ids:
|
||||||
# not included in the output.
|
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||||
seq.output_text = seq.output_text[:-len(stop_str)]
|
return
|
||||||
self.scheduler.free_seq(
|
|
||||||
seq, SequenceStatus.FINISHED_STOPPED)
|
|
||||||
stopped = True
|
|
||||||
break
|
|
||||||
if stopped:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if the sequence has reached max_seq_len.
|
# Check if the sequence has reached max_model_len.
|
||||||
if (seq.get_len() >=
|
if seq.get_len() > self.scheduler_config.max_model_len:
|
||||||
self.scheduler.scheduler_config.max_seq_len):
|
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||||
self.scheduler.free_seq(
|
return
|
||||||
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
|
|
||||||
continue
|
# Check if the sequence has reached max_tokens.
|
||||||
# Check if the sequence has reached max_tokens.
|
if seq.get_output_len() == sampling_params.max_tokens:
|
||||||
if seq.get_output_len() == sampling_params.max_tokens:
|
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||||
self.scheduler.free_seq(
|
return
|
||||||
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
|
|
||||||
continue
|
# Check if the sequence has generated the EOS token.
|
||||||
# Check if the sequence has generated the EOS token.
|
if ((not sampling_params.ignore_eos)
|
||||||
if not sampling_params.ignore_eos:
|
and seq.get_last_token_id() == self.tokenizer.eos_token_id):
|
||||||
if seq.get_last_token_id() == self.tokenizer.eos_token_id:
|
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||||
self.scheduler.free_seq(
|
return
|
||||||
seq, SequenceStatus.FINISHED_STOPPED)
|
|
||||||
continue
|
|
||||||
|
|
||||||
def _run_workers(
|
def _run_workers(
|
||||||
self,
|
self,
|
||||||
@@ -323,9 +684,10 @@ class LLMEngine:
|
|||||||
"""Runs the given method on all workers."""
|
"""Runs the given method on all workers."""
|
||||||
all_outputs = []
|
all_outputs = []
|
||||||
for worker in self.workers:
|
for worker in self.workers:
|
||||||
executor = getattr(worker, method)
|
|
||||||
if self.parallel_config.worker_use_ray:
|
if self.parallel_config.worker_use_ray:
|
||||||
executor = executor.remote
|
executor = partial(worker.execute_method.remote, method)
|
||||||
|
else:
|
||||||
|
executor = getattr(worker, method)
|
||||||
|
|
||||||
output = executor(*args, **kwargs)
|
output = executor(*args, **kwargs)
|
||||||
all_outputs.append(output)
|
all_outputs.append(output)
|
||||||
|
|||||||
@@ -1,22 +1,59 @@
|
|||||||
import random
|
import socket
|
||||||
from typing import List, Optional, Tuple
|
from typing import Optional, Tuple, TYPE_CHECKING
|
||||||
|
|
||||||
|
from vllm.config import ParallelConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ray
|
import ray
|
||||||
except ImportError:
|
from ray.air.util.torch_dist import TorchDistributedWorker
|
||||||
|
|
||||||
|
class RayWorker(TorchDistributedWorker):
|
||||||
|
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
|
||||||
|
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
|
||||||
|
|
||||||
|
def __init__(self, init_cached_hf_modules=False) -> None:
|
||||||
|
if init_cached_hf_modules:
|
||||||
|
# pylint: disable=import-outside-toplevel
|
||||||
|
from transformers.dynamic_module_utils import init_hf_modules
|
||||||
|
init_hf_modules()
|
||||||
|
self.worker = None
|
||||||
|
|
||||||
|
def init_worker(self, worker_init_fn):
|
||||||
|
self.worker = worker_init_fn()
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self.worker, name)
|
||||||
|
|
||||||
|
def execute_method(self, method, *args, **kwargs):
|
||||||
|
executor = getattr(self, method)
|
||||||
|
return executor(*args, **kwargs)
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
logger.warning(f"Failed to import Ray with {e!r}. "
|
||||||
|
"For distributed inference, please install Ray with "
|
||||||
|
"`pip install ray pandas pyarrow`.")
|
||||||
ray = None
|
ray = None
|
||||||
|
TorchDistributedWorker = None
|
||||||
|
RayWorker = None # pylint: disable=invalid-name
|
||||||
|
|
||||||
from vllm.config import ParallelConfig
|
if TYPE_CHECKING:
|
||||||
|
from ray.util.placement_group import PlacementGroup
|
||||||
|
|
||||||
# rank, node resource (node IP), device id
|
|
||||||
DeviceID = Tuple[int, Optional[str], int]
|
def get_open_port():
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.bind(("", 0))
|
||||||
|
return s.getsockname()[1]
|
||||||
|
|
||||||
|
|
||||||
def initialize_cluster(
|
def initialize_cluster(
|
||||||
parallel_config: ParallelConfig,
|
parallel_config: ParallelConfig,
|
||||||
engine_use_ray: bool = False,
|
engine_use_ray: bool = False,
|
||||||
ray_address: Optional[str] = None,
|
ray_address: Optional[str] = None,
|
||||||
) -> Tuple[str, List[List[DeviceID]]]:
|
) -> Tuple[str, Optional["PlacementGroup"]]:
|
||||||
"""Initialize the distributed cluster probably with Ray.
|
"""Initialize the distributed cluster probably with Ray.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -26,11 +63,10 @@ def initialize_cluster(
|
|||||||
the default Ray cluster address.
|
the default Ray cluster address.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (`distributed_init_method`, `all_stage_devices`). The
|
A tuple of (`distributed_init_method`, `placement_group`). The
|
||||||
`distributed_init_method` is the address for initializing the
|
`distributed_init_method` is the address for initializing the
|
||||||
distributed backend. `all_stage_devices` includes device IDs for
|
distributed backend. `placement_group` includes the specification
|
||||||
each worker in each pipeline stage. Each device ID is a tuple of
|
of the resources for each distributed worker.
|
||||||
(rank, node resource, device id).
|
|
||||||
"""
|
"""
|
||||||
if parallel_config.worker_use_ray or engine_use_ray:
|
if parallel_config.worker_use_ray or engine_use_ray:
|
||||||
if ray is None:
|
if ray is None:
|
||||||
@@ -38,71 +74,46 @@ def initialize_cluster(
|
|||||||
"Ray is not installed. Please install Ray to use distributed "
|
"Ray is not installed. Please install Ray to use distributed "
|
||||||
"serving.")
|
"serving.")
|
||||||
# Connect to a ray cluster.
|
# Connect to a ray cluster.
|
||||||
ray.init(address=ray_address)
|
ray.init(address=ray_address, ignore_reinit_error=True)
|
||||||
|
|
||||||
if not parallel_config.worker_use_ray:
|
if not parallel_config.worker_use_ray:
|
||||||
# Initialize cluster locally.
|
# Initialize cluster locally.
|
||||||
port = random.randint(10000, 20000)
|
port = get_open_port()
|
||||||
# We need to setup the distributed init method to make sure
|
# We need to setup the distributed init method to make sure
|
||||||
# the distributed megatron code (e.g., get world size) works correctly.
|
# the distributed megatron code (e.g., get world size) works correctly.
|
||||||
distributed_init_method = f"tcp://localhost:{port}"
|
distributed_init_method = f"tcp://localhost:{port}"
|
||||||
all_stage_devices = [[(0, None, 0)]]
|
return distributed_init_method, None
|
||||||
return distributed_init_method, all_stage_devices
|
|
||||||
|
|
||||||
# Assume we have a uniform cluster that each node has the same number of
|
current_placement_group = ray.util.get_current_placement_group()
|
||||||
# GPUs for now.
|
if current_placement_group:
|
||||||
valid_node_resources = []
|
# We are in a placement group
|
||||||
num_devices_per_node = None
|
bundles = current_placement_group.bundle_specs
|
||||||
for node in ray.nodes():
|
# Verify that we can use the placement group.
|
||||||
if (not node["Alive"]) or node["Resources"]["GPU"] <= 0:
|
gpu_bundles = 0
|
||||||
continue
|
for bundle in bundles:
|
||||||
if num_devices_per_node is None:
|
bundle_gpus = bundle.get("GPU", 0)
|
||||||
num_devices_per_node = node["Resources"]["GPU"]
|
if bundle_gpus > 1:
|
||||||
else:
|
raise ValueError(
|
||||||
assert num_devices_per_node == node["Resources"]["GPU"], (
|
"Placement group bundle cannot have more than 1 GPU.")
|
||||||
"The number of GPUs per node is not uniform.")
|
if bundle_gpus:
|
||||||
for key in node["Resources"]:
|
gpu_bundles += 1
|
||||||
if key.startswith("node:"):
|
if parallel_config.world_size > gpu_bundles:
|
||||||
valid_node_resources.append(key)
|
|
||||||
|
|
||||||
# Verify the parallel config.
|
|
||||||
num_nodes = len(valid_node_resources)
|
|
||||||
if parallel_config.world_size > num_nodes * num_devices_per_node:
|
|
||||||
raise ValueError(
|
|
||||||
"The number of required GPUs exceeds the total number of "
|
|
||||||
"available GPUs.")
|
|
||||||
if parallel_config.tensor_parallel_size >= num_devices_per_node:
|
|
||||||
if parallel_config.tensor_parallel_size % num_devices_per_node != 0:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The number of tensor parallelism is not divisible by the "
|
"The number of required GPUs exceeds the total number of "
|
||||||
"number of GPUs per node.")
|
"available GPUs in the placement group.")
|
||||||
else:
|
else:
|
||||||
if num_devices_per_node % parallel_config.tensor_parallel_size != 0:
|
num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
|
||||||
|
if parallel_config.world_size > num_gpus_in_cluster:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The number of GPUs per node is not divisible by the number "
|
"The number of required GPUs exceeds the total number of "
|
||||||
"of tensor parallelism.")
|
"available GPUs in the cluster.")
|
||||||
|
# Create a new placement group
|
||||||
|
current_placement_group = ray.util.placement_group([{
|
||||||
|
"GPU": 1
|
||||||
|
}] * parallel_config.world_size)
|
||||||
|
# Wait until PG is ready - this will block until all
|
||||||
|
# requested resources are available, and will timeout
|
||||||
|
# if they cannot be provisioned.
|
||||||
|
ray.get(current_placement_group.ready(), timeout=1800)
|
||||||
|
|
||||||
# Assign GPUs to pipeline stages.
|
return None, current_placement_group
|
||||||
rank = 0
|
|
||||||
current_node_id = 0
|
|
||||||
current_device_id = 0
|
|
||||||
distributed_init_method = None
|
|
||||||
all_stage_devices = []
|
|
||||||
|
|
||||||
for _ in range(parallel_config.pipeline_parallel_size):
|
|
||||||
stage_devices = []
|
|
||||||
for _ in range(parallel_config.tensor_parallel_size):
|
|
||||||
node_resource = valid_node_resources[current_node_id]
|
|
||||||
stage_devices.append((rank, node_resource, current_device_id))
|
|
||||||
if distributed_init_method is None:
|
|
||||||
ip = node_resource.split("node:")[-1]
|
|
||||||
port = random.randint(10000, 20000)
|
|
||||||
distributed_init_method = f"tcp://{ip}:{port}"
|
|
||||||
rank += 1
|
|
||||||
current_device_id += 1
|
|
||||||
if current_device_id >= num_devices_per_node:
|
|
||||||
current_node_id += 1
|
|
||||||
current_device_id = 0
|
|
||||||
all_stage_devices.append(stage_devices)
|
|
||||||
|
|
||||||
return distributed_init_method, all_stage_devices
|
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from fastapi import BackgroundTasks, FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.responses import Response, StreamingResponse
|
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
@@ -14,6 +14,7 @@ from vllm.utils import random_uuid
|
|||||||
TIMEOUT_KEEP_ALIVE = 5 # seconds.
|
TIMEOUT_KEEP_ALIVE = 5 # seconds.
|
||||||
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
|
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
engine = None
|
||||||
|
|
||||||
|
|
||||||
@app.post("/generate")
|
@app.post("/generate")
|
||||||
@@ -30,6 +31,7 @@ async def generate(request: Request) -> Response:
|
|||||||
stream = request_dict.pop("stream", False)
|
stream = request_dict.pop("stream", False)
|
||||||
sampling_params = SamplingParams(**request_dict)
|
sampling_params = SamplingParams(**request_dict)
|
||||||
request_id = random_uuid()
|
request_id = random_uuid()
|
||||||
|
|
||||||
results_generator = engine.generate(prompt, sampling_params, request_id)
|
results_generator = engine.generate(prompt, sampling_params, request_id)
|
||||||
|
|
||||||
# Streaming case
|
# Streaming case
|
||||||
@@ -42,14 +44,8 @@ async def generate(request: Request) -> Response:
|
|||||||
ret = {"text": text_outputs}
|
ret = {"text": text_outputs}
|
||||||
yield (json.dumps(ret) + "\0").encode("utf-8")
|
yield (json.dumps(ret) + "\0").encode("utf-8")
|
||||||
|
|
||||||
async def abort_request() -> None:
|
|
||||||
await engine.abort(request_id)
|
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
background_tasks = BackgroundTasks()
|
return StreamingResponse(stream_results())
|
||||||
# Abort the request if the client disconnects.
|
|
||||||
background_tasks.add_task(abort_request)
|
|
||||||
return StreamingResponse(stream_results(), background=background_tasks)
|
|
||||||
|
|
||||||
# Non-streaming case
|
# Non-streaming case
|
||||||
final_output = None
|
final_output = None
|
||||||
@@ -64,7 +60,7 @@ async def generate(request: Request) -> Response:
|
|||||||
prompt = final_output.prompt
|
prompt = final_output.prompt
|
||||||
text_outputs = [prompt + output.text for output in final_output.outputs]
|
text_outputs = [prompt + output.text for output in final_output.outputs]
|
||||||
ret = {"text": text_outputs}
|
ret = {"text": text_outputs}
|
||||||
return Response(content=json.dumps(ret))
|
return JSONResponse(ret)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ class LLM:
|
|||||||
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
|
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
|
||||||
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
|
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
|
||||||
if available, and "slow" will always use the slow tokenizer.
|
if available, and "slow" will always use the slow tokenizer.
|
||||||
|
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
|
||||||
|
downloading the model and tokenizer.
|
||||||
tensor_parallel_size: The number of GPUs to use for distributed
|
tensor_parallel_size: The number of GPUs to use for distributed
|
||||||
execution with tensor parallelism.
|
execution with tensor parallelism.
|
||||||
dtype: The data type for the model weights and activations. Currently,
|
dtype: The data type for the model weights and activations. Currently,
|
||||||
@@ -35,7 +37,22 @@ class LLM:
|
|||||||
the `torch_dtype` attribute specified in the model config file.
|
the `torch_dtype` attribute specified in the model config file.
|
||||||
However, if the `torch_dtype` in the config is `float32`, we will
|
However, if the `torch_dtype` in the config is `float32`, we will
|
||||||
use `float16` instead.
|
use `float16` instead.
|
||||||
|
quantization: The method used to quantize the model weights. Currently,
|
||||||
|
we support "awq". If None, we assume the model weights are not
|
||||||
|
quantized and use `dtype` to determine the data type of the weights.
|
||||||
|
revision: The specific model version to use. It can be a branch name,
|
||||||
|
a tag name, or a commit id.
|
||||||
seed: The seed to initialize the random number generator for sampling.
|
seed: The seed to initialize the random number generator for sampling.
|
||||||
|
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
|
||||||
|
reserve for the model weights, activations, and KV cache. Higher
|
||||||
|
values will increase the KV cache size and thus improve the model's
|
||||||
|
throughput. However, if the value is too high, it may cause out-of-
|
||||||
|
memory (OOM) errors.
|
||||||
|
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
|
||||||
|
This can be used for temporarily storing the states of the requests
|
||||||
|
when their `best_of` sampling parameters are larger than 1. If all
|
||||||
|
requests will have `best_of=1`, you can safely set this to 0.
|
||||||
|
Otherwise, too small values may cause out-of-memory (OOM) errors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -43,9 +60,14 @@ class LLM:
|
|||||||
model: str,
|
model: str,
|
||||||
tokenizer: Optional[str] = None,
|
tokenizer: Optional[str] = None,
|
||||||
tokenizer_mode: str = "auto",
|
tokenizer_mode: str = "auto",
|
||||||
|
trust_remote_code: bool = False,
|
||||||
tensor_parallel_size: int = 1,
|
tensor_parallel_size: int = 1,
|
||||||
dtype: str = "auto",
|
dtype: str = "auto",
|
||||||
|
quantization: Optional[str] = None,
|
||||||
|
revision: Optional[str] = None,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
|
gpu_memory_utilization: float = 0.9,
|
||||||
|
swap_space: int = 4,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
if "disable_log_stats" not in kwargs:
|
if "disable_log_stats" not in kwargs:
|
||||||
@@ -54,9 +76,14 @@ class LLM:
|
|||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
tokenizer_mode=tokenizer_mode,
|
tokenizer_mode=tokenizer_mode,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
tensor_parallel_size=tensor_parallel_size,
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
quantization=quantization,
|
||||||
|
revision=revision,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
gpu_memory_utilization=gpu_memory_utilization,
|
||||||
|
swap_space=swap_space,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.llm_engine = LLMEngine.from_engine_args(engine_args)
|
self.llm_engine = LLMEngine.from_engine_args(engine_args)
|
||||||
@@ -151,4 +178,8 @@ class LLM:
|
|||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
if use_tqdm:
|
if use_tqdm:
|
||||||
pbar.close()
|
pbar.close()
|
||||||
|
# Sort the outputs by request ID.
|
||||||
|
# This is necessary because some requests may be finished earlier than
|
||||||
|
# its previous requests.
|
||||||
|
outputs = sorted(outputs, key=lambda x: int(x.request_id))
|
||||||
return outputs
|
return outputs
|
||||||
|
|||||||
@@ -3,20 +3,18 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
from http import HTTPStatus
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import AsyncGenerator, Dict, List, Optional
|
from http import HTTPStatus
|
||||||
|
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
from fastapi import BackgroundTasks, Request
|
import uvicorn
|
||||||
|
from fastapi import Request
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from fastchat.conversation import Conversation, SeparatorStyle
|
from packaging import version
|
||||||
from fastchat.model.model_adapter import get_conversation_template
|
|
||||||
|
|
||||||
import uvicorn
|
|
||||||
|
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
@@ -33,11 +31,20 @@ from vllm.sampling_params import SamplingParams
|
|||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
from vllm.utils import random_uuid
|
from vllm.utils import random_uuid
|
||||||
|
|
||||||
|
try:
|
||||||
|
import fastchat
|
||||||
|
from fastchat.conversation import Conversation, SeparatorStyle
|
||||||
|
from fastchat.model.model_adapter import get_conversation_template
|
||||||
|
_fastchat_available = True
|
||||||
|
except ImportError:
|
||||||
|
_fastchat_available = False
|
||||||
|
|
||||||
TIMEOUT_KEEP_ALIVE = 5 # seconds
|
TIMEOUT_KEEP_ALIVE = 5 # seconds
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
served_model = None
|
served_model = None
|
||||||
app = fastapi.FastAPI()
|
app = fastapi.FastAPI()
|
||||||
|
engine = None
|
||||||
|
|
||||||
|
|
||||||
def create_error_response(status_code: HTTPStatus,
|
def create_error_response(status_code: HTTPStatus,
|
||||||
@@ -63,10 +70,21 @@ async def check_model(request) -> Optional[JSONResponse]:
|
|||||||
|
|
||||||
|
|
||||||
async def get_gen_prompt(request) -> str:
|
async def get_gen_prompt(request) -> str:
|
||||||
|
if not _fastchat_available:
|
||||||
|
raise ModuleNotFoundError(
|
||||||
|
"fastchat is not installed. Please install fastchat to use "
|
||||||
|
"the chat completion and conversation APIs: `$ pip install fschat`"
|
||||||
|
)
|
||||||
|
if version.parse(fastchat.__version__) < version.parse("0.2.23"):
|
||||||
|
raise ImportError(
|
||||||
|
f"fastchat version is low. Current version: {fastchat.__version__} "
|
||||||
|
"Please upgrade fastchat to use: `$ pip install -U fschat`")
|
||||||
|
|
||||||
conv = get_conversation_template(request.model)
|
conv = get_conversation_template(request.model)
|
||||||
conv = Conversation(
|
conv = Conversation(
|
||||||
name=conv.name,
|
name=conv.name,
|
||||||
system=conv.system,
|
system_template=conv.system_template,
|
||||||
|
system_message=conv.system_message,
|
||||||
roles=conv.roles,
|
roles=conv.roles,
|
||||||
messages=list(conv.messages), # prevent in-place modification
|
messages=list(conv.messages), # prevent in-place modification
|
||||||
offset=conv.offset,
|
offset=conv.offset,
|
||||||
@@ -83,7 +101,7 @@ async def get_gen_prompt(request) -> str:
|
|||||||
for message in request.messages:
|
for message in request.messages:
|
||||||
msg_role = message["role"]
|
msg_role = message["role"]
|
||||||
if msg_role == "system":
|
if msg_role == "system":
|
||||||
conv.system = message["content"]
|
conv.system_message = message["content"]
|
||||||
elif msg_role == "user":
|
elif msg_role == "user":
|
||||||
conv.append_message(conv.roles[0], message["content"])
|
conv.append_message(conv.roles[0], message["content"])
|
||||||
elif msg_role == "assistant":
|
elif msg_role == "assistant":
|
||||||
@@ -98,32 +116,33 @@ async def get_gen_prompt(request) -> str:
|
|||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
async def check_length(request, prompt, model_config):
|
async def check_length(
|
||||||
if hasattr(model_config.hf_config, "max_sequence_length"):
|
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||||
context_len = model_config.hf_config.max_sequence_length
|
prompt: Optional[str] = None,
|
||||||
elif hasattr(model_config.hf_config, "seq_length"):
|
prompt_ids: Optional[List[int]] = None
|
||||||
context_len = model_config.hf_config.seq_length
|
) -> Tuple[List[int], Optional[JSONResponse]]:
|
||||||
elif hasattr(model_config.hf_config, "max_position_embeddings"):
|
assert (not (prompt is None and prompt_ids is None)
|
||||||
context_len = model_config.hf_config.max_position_embeddings
|
and not (prompt is not None and prompt_ids is not None)
|
||||||
elif hasattr(model_config.hf_config, "seq_length"):
|
), "Either prompt or prompt_ids should be provided."
|
||||||
context_len = model_config.hf_config.seq_length
|
if prompt_ids is not None:
|
||||||
|
input_ids = prompt_ids
|
||||||
else:
|
else:
|
||||||
context_len = 2048
|
input_ids = tokenizer(prompt).input_ids
|
||||||
|
|
||||||
input_ids = tokenizer(prompt).input_ids
|
|
||||||
token_num = len(input_ids)
|
token_num = len(input_ids)
|
||||||
|
|
||||||
if token_num + request.max_tokens > context_len:
|
if request.max_tokens is None:
|
||||||
return create_error_response(
|
request.max_tokens = max_model_len - token_num
|
||||||
|
if token_num + request.max_tokens > max_model_len:
|
||||||
|
return input_ids, create_error_response(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
f"This model's maximum context length is {context_len} tokens. "
|
f"This model's maximum context length is {max_model_len} tokens. "
|
||||||
f"However, you requested {request.max_tokens + token_num} tokens "
|
f"However, you requested {request.max_tokens + token_num} tokens "
|
||||||
f"({token_num} in the messages, "
|
f"({token_num} in the messages, "
|
||||||
f"{request.max_tokens} in the completion). "
|
f"{request.max_tokens} in the completion). "
|
||||||
f"Please reduce the length of the messages or completion.",
|
f"Please reduce the length of the messages or completion.",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return None
|
return input_ids, None
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/models")
|
@app.get("/v1/models")
|
||||||
@@ -162,7 +181,8 @@ def create_logprobs(token_ids: List[int],
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
@app.post("/v1/chat/completions")
|
||||||
async def create_chat_completion(raw_request: Request):
|
async def create_chat_completion(request: ChatCompletionRequest,
|
||||||
|
raw_request: Request):
|
||||||
"""Completion API similar to OpenAI's API.
|
"""Completion API similar to OpenAI's API.
|
||||||
|
|
||||||
See https://platform.openai.com/docs/api-reference/chat/create
|
See https://platform.openai.com/docs/api-reference/chat/create
|
||||||
@@ -172,20 +192,19 @@ async def create_chat_completion(raw_request: Request):
|
|||||||
- function_call (Users should implement this by themselves)
|
- function_call (Users should implement this by themselves)
|
||||||
- logit_bias (to be supported by vLLM engine)
|
- logit_bias (to be supported by vLLM engine)
|
||||||
"""
|
"""
|
||||||
request = ChatCompletionRequest(**await raw_request.json())
|
|
||||||
logger.info(f"Received chat completion request: {request}")
|
logger.info(f"Received chat completion request: {request}")
|
||||||
|
|
||||||
error_check_ret = await check_model(request)
|
error_check_ret = await check_model(request)
|
||||||
if error_check_ret is not None:
|
if error_check_ret is not None:
|
||||||
return error_check_ret
|
return error_check_ret
|
||||||
|
|
||||||
if request.logit_bias is not None:
|
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||||
# TODO: support logit_bias in vLLM engine.
|
# TODO: support logit_bias in vLLM engine.
|
||||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||||
"logit_bias is not currently supported")
|
"logit_bias is not currently supported")
|
||||||
|
|
||||||
prompt = await get_gen_prompt(request)
|
prompt = await get_gen_prompt(request)
|
||||||
error_check_ret = await check_length(request, prompt, engine_model_config)
|
token_ids, error_check_ret = await check_length(request, prompt=prompt)
|
||||||
if error_check_ret is not None:
|
if error_check_ret is not None:
|
||||||
return error_check_ret
|
return error_check_ret
|
||||||
|
|
||||||
@@ -200,19 +219,19 @@ async def create_chat_completion(raw_request: Request):
|
|||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
stop=request.stop,
|
stop=request.stop,
|
||||||
|
stop_token_ids=request.stop_token_ids,
|
||||||
max_tokens=request.max_tokens,
|
max_tokens=request.max_tokens,
|
||||||
best_of=request.best_of,
|
best_of=request.best_of,
|
||||||
top_k=request.top_k,
|
top_k=request.top_k,
|
||||||
ignore_eos=request.ignore_eos,
|
ignore_eos=request.ignore_eos,
|
||||||
use_beam_search=request.use_beam_search,
|
use_beam_search=request.use_beam_search,
|
||||||
|
skip_special_tokens=request.skip_special_tokens,
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||||
|
|
||||||
result_generator = engine.generate(prompt, sampling_params, request_id)
|
result_generator = engine.generate(prompt, sampling_params, request_id,
|
||||||
|
token_ids)
|
||||||
async def abort_request() -> None:
|
|
||||||
await engine.abort(request_id)
|
|
||||||
|
|
||||||
def create_stream_response_json(
|
def create_stream_response_json(
|
||||||
index: int,
|
index: int,
|
||||||
@@ -269,23 +288,19 @@ async def create_chat_completion(raw_request: Request):
|
|||||||
finish_reason=output.finish_reason,
|
finish_reason=output.finish_reason,
|
||||||
)
|
)
|
||||||
yield f"data: {response_json}\n\n"
|
yield f"data: {response_json}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
# Streaming response
|
# Streaming response
|
||||||
if request.stream:
|
if request.stream:
|
||||||
background_tasks = BackgroundTasks()
|
|
||||||
# Abort the request if the client disconnects.
|
|
||||||
background_tasks.add_task(abort_request)
|
|
||||||
return StreamingResponse(completion_stream_generator(),
|
return StreamingResponse(completion_stream_generator(),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream")
|
||||||
background=background_tasks)
|
|
||||||
|
|
||||||
# Non-streaming response
|
# Non-streaming response
|
||||||
final_res: RequestOutput = None
|
final_res: RequestOutput = None
|
||||||
async for res in result_generator:
|
async for res in result_generator:
|
||||||
if await raw_request.is_disconnected():
|
if await raw_request.is_disconnected():
|
||||||
# Abort the request if the client disconnects.
|
# Abort the request if the client disconnects.
|
||||||
await abort_request()
|
await engine.abort(request_id)
|
||||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||||
"Client disconnected")
|
"Client disconnected")
|
||||||
final_res = res
|
final_res = res
|
||||||
@@ -331,7 +346,7 @@ async def create_chat_completion(raw_request: Request):
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/v1/completions")
|
@app.post("/v1/completions")
|
||||||
async def create_completion(raw_request: Request):
|
async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||||
"""Completion API similar to OpenAI's API.
|
"""Completion API similar to OpenAI's API.
|
||||||
|
|
||||||
See https://platform.openai.com/docs/api-reference/completions/create
|
See https://platform.openai.com/docs/api-reference/completions/create
|
||||||
@@ -344,7 +359,6 @@ async def create_completion(raw_request: Request):
|
|||||||
suffix)
|
suffix)
|
||||||
- logit_bias (to be supported by vLLM engine)
|
- logit_bias (to be supported by vLLM engine)
|
||||||
"""
|
"""
|
||||||
request = CompletionRequest(**await raw_request.json())
|
|
||||||
logger.info(f"Received completion request: {request}")
|
logger.info(f"Received completion request: {request}")
|
||||||
|
|
||||||
error_check_ret = await check_model(request)
|
error_check_ret = await check_model(request)
|
||||||
@@ -362,24 +376,41 @@ async def create_completion(raw_request: Request):
|
|||||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||||
"suffix is not currently supported")
|
"suffix is not currently supported")
|
||||||
|
|
||||||
if request.logit_bias is not None:
|
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||||
# TODO: support logit_bias in vLLM engine.
|
# TODO: support logit_bias in vLLM engine.
|
||||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||||
"logit_bias is not currently supported")
|
"logit_bias is not currently supported")
|
||||||
|
|
||||||
model_name = request.model
|
model_name = request.model
|
||||||
request_id = f"cmpl-{random_uuid()}"
|
request_id = f"cmpl-{random_uuid()}"
|
||||||
|
|
||||||
|
use_token_ids = False
|
||||||
if isinstance(request.prompt, list):
|
if isinstance(request.prompt, list):
|
||||||
if len(request.prompt) == 0:
|
if len(request.prompt) == 0:
|
||||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||||
"please provide at least one prompt")
|
"please provide at least one prompt")
|
||||||
if len(request.prompt) > 1:
|
first_element = request.prompt[0]
|
||||||
return create_error_response(
|
if isinstance(first_element, int):
|
||||||
HTTPStatus.BAD_REQUEST,
|
use_token_ids = True
|
||||||
"multiple prompts in a batch is not currently supported")
|
prompt = request.prompt
|
||||||
prompt = request.prompt[0]
|
elif isinstance(first_element, (str, list)):
|
||||||
|
# TODO: handles multiple prompt case in list[list[int]]
|
||||||
|
if len(request.prompt) > 1:
|
||||||
|
return create_error_response(
|
||||||
|
HTTPStatus.BAD_REQUEST,
|
||||||
|
"multiple prompts in a batch is not currently supported")
|
||||||
|
use_token_ids = not isinstance(first_element, str)
|
||||||
|
prompt = request.prompt[0]
|
||||||
else:
|
else:
|
||||||
prompt = request.prompt
|
prompt = request.prompt
|
||||||
|
|
||||||
|
if use_token_ids:
|
||||||
|
_, error_check_ret = await check_length(request, prompt_ids=prompt)
|
||||||
|
else:
|
||||||
|
token_ids, error_check_ret = await check_length(request, prompt=prompt)
|
||||||
|
if error_check_ret is not None:
|
||||||
|
return error_check_ret
|
||||||
|
|
||||||
created_time = int(time.time())
|
created_time = int(time.time())
|
||||||
try:
|
try:
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
@@ -391,15 +422,24 @@ async def create_completion(raw_request: Request):
|
|||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
top_k=request.top_k,
|
top_k=request.top_k,
|
||||||
stop=request.stop,
|
stop=request.stop,
|
||||||
|
stop_token_ids=request.stop_token_ids,
|
||||||
ignore_eos=request.ignore_eos,
|
ignore_eos=request.ignore_eos,
|
||||||
max_tokens=request.max_tokens,
|
max_tokens=request.max_tokens,
|
||||||
logprobs=request.logprobs,
|
logprobs=request.logprobs,
|
||||||
use_beam_search=request.use_beam_search,
|
use_beam_search=request.use_beam_search,
|
||||||
|
skip_special_tokens=request.skip_special_tokens,
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||||
|
|
||||||
result_generator = engine.generate(prompt, sampling_params, request_id)
|
if use_token_ids:
|
||||||
|
result_generator = engine.generate(None,
|
||||||
|
sampling_params,
|
||||||
|
request_id,
|
||||||
|
prompt_token_ids=prompt)
|
||||||
|
else:
|
||||||
|
result_generator = engine.generate(prompt, sampling_params, request_id,
|
||||||
|
token_ids)
|
||||||
|
|
||||||
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
||||||
# results. In addition, we do not stream the results when use beam search.
|
# results. In addition, we do not stream the results when use beam search.
|
||||||
@@ -407,9 +447,6 @@ async def create_completion(raw_request: Request):
|
|||||||
and (request.best_of is None or request.n == request.best_of)
|
and (request.best_of is None or request.n == request.best_of)
|
||||||
and not request.use_beam_search)
|
and not request.use_beam_search)
|
||||||
|
|
||||||
async def abort_request() -> None:
|
|
||||||
await engine.abort(request_id)
|
|
||||||
|
|
||||||
def create_stream_response_json(
|
def create_stream_response_json(
|
||||||
index: int,
|
index: int,
|
||||||
text: str,
|
text: str,
|
||||||
@@ -465,23 +502,19 @@ async def create_completion(raw_request: Request):
|
|||||||
finish_reason=output.finish_reason,
|
finish_reason=output.finish_reason,
|
||||||
)
|
)
|
||||||
yield f"data: {response_json}\n\n"
|
yield f"data: {response_json}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
# Streaming response
|
# Streaming response
|
||||||
if stream:
|
if stream:
|
||||||
background_tasks = BackgroundTasks()
|
|
||||||
# Abort the request if the client disconnects.
|
|
||||||
background_tasks.add_task(abort_request)
|
|
||||||
return StreamingResponse(completion_stream_generator(),
|
return StreamingResponse(completion_stream_generator(),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream")
|
||||||
background=background_tasks)
|
|
||||||
|
|
||||||
# Non-streaming response
|
# Non-streaming response
|
||||||
final_res: RequestOutput = None
|
final_res: RequestOutput = None
|
||||||
async for res in result_generator:
|
async for res in result_generator:
|
||||||
if await raw_request.is_disconnected():
|
if await raw_request.is_disconnected():
|
||||||
# Abort the request if the client disconnects.
|
# Abort the request if the client disconnects.
|
||||||
await abort_request()
|
await engine.abort(request_id)
|
||||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||||
"Client disconnected")
|
"Client disconnected")
|
||||||
final_res = res
|
final_res = res
|
||||||
@@ -582,10 +615,12 @@ if __name__ == "__main__":
|
|||||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||||
engine_model_config = asyncio.run(engine.get_model_config())
|
engine_model_config = asyncio.run(engine.get_model_config())
|
||||||
|
max_model_len = engine_model_config.max_model_len
|
||||||
|
|
||||||
# A separate tokenizer to map token IDs to strings.
|
# A separate tokenizer to map token IDs to strings.
|
||||||
tokenizer = get_tokenizer(engine_args.tokenizer,
|
tokenizer = get_tokenizer(engine_args.tokenizer,
|
||||||
tokenizer_mode=engine_args.tokenizer_mode)
|
tokenizer_mode=engine_args.tokenizer_mode,
|
||||||
|
trust_remote_code=engine_args.trust_remote_code)
|
||||||
|
|
||||||
uvicorn.run(app,
|
uvicorn.run(app,
|
||||||
host=args.host,
|
host=args.host,
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
temperature: Optional[float] = 0.7
|
temperature: Optional[float] = 0.7
|
||||||
top_p: Optional[float] = 1.0
|
top_p: Optional[float] = 1.0
|
||||||
n: Optional[int] = 1
|
n: Optional[int] = 1
|
||||||
max_tokens: Optional[int] = 16
|
max_tokens: Optional[int] = None
|
||||||
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
||||||
stream: Optional[bool] = False
|
stream: Optional[bool] = False
|
||||||
presence_penalty: Optional[float] = 0.0
|
presence_penalty: Optional[float] = 0.0
|
||||||
@@ -70,11 +70,14 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
top_k: Optional[int] = -1
|
top_k: Optional[int] = -1
|
||||||
ignore_eos: Optional[bool] = False
|
ignore_eos: Optional[bool] = False
|
||||||
use_beam_search: Optional[bool] = False
|
use_beam_search: Optional[bool] = False
|
||||||
|
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||||
|
skip_special_tokens: Optional[bool] = True
|
||||||
|
|
||||||
|
|
||||||
class CompletionRequest(BaseModel):
|
class CompletionRequest(BaseModel):
|
||||||
model: str
|
model: str
|
||||||
prompt: Union[str, List[str]]
|
# a string, array of strings, array of tokens, or array of token arrays
|
||||||
|
prompt: Union[List[int], List[List[int]], str, List[str]]
|
||||||
suffix: Optional[str] = None
|
suffix: Optional[str] = None
|
||||||
max_tokens: Optional[int] = 16
|
max_tokens: Optional[int] = 16
|
||||||
temperature: Optional[float] = 1.0
|
temperature: Optional[float] = 1.0
|
||||||
@@ -93,6 +96,8 @@ class CompletionRequest(BaseModel):
|
|||||||
top_k: Optional[int] = -1
|
top_k: Optional[int] = -1
|
||||||
ignore_eos: Optional[bool] = False
|
ignore_eos: Optional[bool] = False
|
||||||
use_beam_search: Optional[bool] = False
|
use_beam_search: Optional[bool] = False
|
||||||
|
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||||
|
skip_special_tokens: Optional[bool] = True
|
||||||
|
|
||||||
|
|
||||||
class LogProbs(BaseModel):
|
class LogProbs(BaseModel):
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from xformers.ops import AttentionBias
|
from xformers.ops import AttentionBias
|
||||||
@@ -29,6 +29,7 @@ class InputMetadata:
|
|||||||
context_lens: torch.Tensor,
|
context_lens: torch.Tensor,
|
||||||
max_context_len: int,
|
max_context_len: int,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.seq_groups = seq_groups
|
self.seq_groups = seq_groups
|
||||||
self.seq_data = seq_data
|
self.seq_data = seq_data
|
||||||
@@ -38,6 +39,24 @@ class InputMetadata:
|
|||||||
self.max_context_len = max_context_len
|
self.max_context_len = max_context_len
|
||||||
self.block_tables = block_tables
|
self.block_tables = block_tables
|
||||||
|
|
||||||
|
self.to_cache = None
|
||||||
|
if sliding_window is not None:
|
||||||
|
# We need to keep the positions of sliding windows within
|
||||||
|
# the key / value tables, this is helpful to know which
|
||||||
|
# elements we need to cache and where
|
||||||
|
to_cache, start_idx = [], 0
|
||||||
|
for prompt_len in self.prompt_lens:
|
||||||
|
to_cache.extend(
|
||||||
|
range(
|
||||||
|
start_idx + max(0, prompt_len - sliding_window),
|
||||||
|
start_idx + prompt_len,
|
||||||
|
))
|
||||||
|
start_idx += prompt_len
|
||||||
|
to_cache.extend(range(start_idx, slot_mapping.shape[0]))
|
||||||
|
self.to_cache = torch.tensor(to_cache,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.slot_mapping.device)
|
||||||
|
|
||||||
self.num_prompts = len(prompt_lens)
|
self.num_prompts = len(prompt_lens)
|
||||||
self.num_prompt_tokens = sum(prompt_lens)
|
self.num_prompt_tokens = sum(prompt_lens)
|
||||||
self.num_generation_tokens = context_lens.shape[0]
|
self.num_generation_tokens = context_lens.shape[0]
|
||||||
|
|||||||
@@ -4,23 +4,6 @@ import torch.nn as nn
|
|||||||
|
|
||||||
from vllm import activation_ops
|
from vllm import activation_ops
|
||||||
|
|
||||||
_ACTIVATION_REGISTRY = {
|
|
||||||
"gelu": nn.GELU(),
|
|
||||||
# NOTE: The following GELU functions may introduce small rounding errors.
|
|
||||||
"gelu_new": nn.GELU(approximate="tanh"),
|
|
||||||
"gelu_fast": nn.GELU(approximate="tanh"),
|
|
||||||
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
|
|
||||||
"relu": nn.ReLU(),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_act_fn(act_fn: str) -> nn.Module:
|
|
||||||
"""Get an activation function by name."""
|
|
||||||
act_fn = act_fn.lower()
|
|
||||||
if act_fn in _ACTIVATION_REGISTRY:
|
|
||||||
return _ACTIVATION_REGISTRY[act_fn]
|
|
||||||
raise ValueError(f"Activation function {act_fn!r} is not supported.")
|
|
||||||
|
|
||||||
|
|
||||||
class SiluAndMul(nn.Module):
|
class SiluAndMul(nn.Module):
|
||||||
"""An activation function for SwiGLU.
|
"""An activation function for SwiGLU.
|
||||||
@@ -38,3 +21,40 @@ class SiluAndMul(nn.Module):
|
|||||||
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
|
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
|
||||||
activation_ops.silu_and_mul(out, x)
|
activation_ops.silu_and_mul(out, x)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class NewGELU(nn.Module):
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
num_tokens = x.shape[0]
|
||||||
|
d = x.shape[1]
|
||||||
|
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
|
||||||
|
activation_ops.gelu_new(out, x)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class FastGELU(nn.Module):
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
num_tokens = x.shape[0]
|
||||||
|
d = x.shape[1]
|
||||||
|
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
|
||||||
|
activation_ops.gelu_fast(out, x)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
_ACTIVATION_REGISTRY = {
|
||||||
|
"gelu": nn.GELU(),
|
||||||
|
"gelu_fast": FastGELU(),
|
||||||
|
"gelu_new": NewGELU(),
|
||||||
|
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
|
||||||
|
"relu": nn.ReLU(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_act_fn(act_fn: str) -> nn.Module:
|
||||||
|
"""Get an activation function by name."""
|
||||||
|
act_fn = act_fn.lower()
|
||||||
|
if act_fn in _ACTIVATION_REGISTRY:
|
||||||
|
return _ACTIVATION_REGISTRY[act_fn]
|
||||||
|
raise ValueError(f"Activation function {act_fn!r} is not supported.")
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""Multi-head attention."""
|
"""Multi-head attention."""
|
||||||
from typing import List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -9,10 +9,12 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
|
|||||||
|
|
||||||
from vllm import attention_ops
|
from vllm import attention_ops
|
||||||
from vllm import cache_ops
|
from vllm import cache_ops
|
||||||
from vllm import pos_encoding_ops
|
|
||||||
from vllm.model_executor.input_metadata import InputMetadata
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import (
|
||||||
|
DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
|
||||||
|
RotaryEmbedding)
|
||||||
|
|
||||||
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128]
|
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||||
|
|
||||||
|
|
||||||
class PagedAttention(nn.Module):
|
class PagedAttention(nn.Module):
|
||||||
@@ -20,12 +22,20 @@ class PagedAttention(nn.Module):
|
|||||||
"""GPT-style multi-head PagedAttention.
|
"""GPT-style multi-head PagedAttention.
|
||||||
|
|
||||||
This class takes flattened 1D query, key, and value tensors as input. The
|
This class takes flattened 1D query, key, and value tensors as input. The
|
||||||
input 1D tensors can be split into three parts: the prompt tokens, the
|
input 1D tensors can either contain prompt tokens or generation tokens, in
|
||||||
generation tokens, and the paddings.
|
addition to paddings.
|
||||||
|
|
||||||
|<------------------------------------- num_valid_tokens ------------------------------------->|
|
If the input tensors contain prompt tokens, the layout is as follows:
|
||||||
|<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
|
|
||||||
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
|
|<---------------------- num_valid_tokens ---------------------->|
|
||||||
|
|<--------------- num_prompt_tokens -------------->|
|
||||||
|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|
|
||||||
|
|
||||||
|
Otherwise, the layout is as follows:
|
||||||
|
|
||||||
|
|<------------------ num_valid_tokens ------------------->|
|
||||||
|
|<------- num_generation_tokens (M) ------->|
|
||||||
|
|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
|
||||||
|
|
||||||
The prompts might have different lengths, while the generation tokens always
|
The prompts might have different lengths, while the generation tokens always
|
||||||
have length 1. The paddings are appended to make the input length a multiple
|
have length 1. The paddings are appended to make the input length a multiple
|
||||||
@@ -44,23 +54,42 @@ class PagedAttention(nn.Module):
|
|||||||
5. Output a flattened 1D tensor.
|
5. Output a flattened 1D tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
|
def __init__(self,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
scale: float,
|
||||||
|
num_kv_heads: Optional[int] = None,
|
||||||
|
sliding_window: Optional[int] = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
self.attn_op = xops.fmha.cutlass.FwOp()
|
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
|
||||||
|
assert self.num_heads % self.num_kv_heads == 0
|
||||||
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
self.head_mapping = torch.repeat_interleave(
|
||||||
|
torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
|
||||||
|
self.num_queries_per_kv)
|
||||||
|
|
||||||
if self.head_size not in _SUPPORTED_HEAD_SIZES:
|
if self.head_size not in _SUPPORTED_HEAD_SIZES:
|
||||||
raise ValueError(f"head_size ({self.head_size}) is not supported. "
|
raise ValueError(f"head_size ({self.head_size}) is not supported. "
|
||||||
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
|
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
|
||||||
|
|
||||||
def set_attn_bias(self, input_metadata: InputMetadata) -> None:
|
def set_attn_bias(
|
||||||
|
self,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> None:
|
||||||
|
del dtype # Unused.
|
||||||
if input_metadata.attn_bias:
|
if input_metadata.attn_bias:
|
||||||
# Already set by a previous layer.
|
# Already set by a previous layer.
|
||||||
return
|
return
|
||||||
prompt_lens = input_metadata.prompt_lens
|
prompt_lens = input_metadata.prompt_lens
|
||||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
|
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
|
||||||
|
if self.sliding_window is not None:
|
||||||
|
attn_bias = attn_bias.make_local_attention(self.sliding_window)
|
||||||
input_metadata.attn_bias.append(attn_bias)
|
input_metadata.attn_bias.append(attn_bias)
|
||||||
|
|
||||||
def multi_query_kv_attention(
|
def multi_query_kv_attention(
|
||||||
@@ -76,10 +105,18 @@ class PagedAttention(nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
output: shape = [num_prompt_tokens, num_heads, head_size]
|
output: shape = [num_prompt_tokens, num_heads, head_size]
|
||||||
query: shape = [num_prompt_tokens, num_heads, head_size]
|
query: shape = [num_prompt_tokens, num_heads, head_size]
|
||||||
key: shape = [num_prompt_tokens, num_heads, head_size]
|
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
|
||||||
value: shape = [num_prompt_tokens, num_heads, head_size]
|
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
|
||||||
input_metadata: metadata for paged attention.
|
input_metadata: metadata for paged attention.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if self.num_kv_heads != self.num_heads:
|
||||||
|
# Project the key and value tensors to the desired number of heads.
|
||||||
|
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
|
||||||
|
value = torch.repeat_interleave(value,
|
||||||
|
self.num_queries_per_kv,
|
||||||
|
dim=1)
|
||||||
|
|
||||||
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
|
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
|
||||||
out = xops.memory_efficient_attention_forward(
|
out = xops.memory_efficient_attention_forward(
|
||||||
query.unsqueeze(0),
|
query.unsqueeze(0),
|
||||||
@@ -88,7 +125,6 @@ class PagedAttention(nn.Module):
|
|||||||
attn_bias=input_metadata.attn_bias[0],
|
attn_bias=input_metadata.attn_bias[0],
|
||||||
p=0.0,
|
p=0.0,
|
||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
op=self.attn_op,
|
|
||||||
)
|
)
|
||||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||||
output.copy_(out.squeeze(0))
|
output.copy_(out.squeeze(0))
|
||||||
@@ -107,9 +143,10 @@ class PagedAttention(nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
output: shape = [num_generation_tokens, num_heads, head_size]
|
output: shape = [num_generation_tokens, num_heads, head_size]
|
||||||
query: shape = [num_generation_tokens, num_heads, head_size]
|
query: shape = [num_generation_tokens, num_heads, head_size]
|
||||||
key_cache: shape = [num_blocks, num_heads, head_size/x,
|
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
|
||||||
block_size, x]
|
block_size, x]
|
||||||
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
|
value_cache: shape = [num_blocks, num_kv_heads, head_size,
|
||||||
|
block_size]
|
||||||
input_metadata: metadata for paged attention.
|
input_metadata: metadata for paged attention.
|
||||||
"""
|
"""
|
||||||
block_size = value_cache.shape[3]
|
block_size = value_cache.shape[3]
|
||||||
@@ -118,6 +155,7 @@ class PagedAttention(nn.Module):
|
|||||||
query,
|
query,
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
|
self.head_mapping,
|
||||||
self.scale,
|
self.scale,
|
||||||
input_metadata.block_tables,
|
input_metadata.block_tables,
|
||||||
input_metadata.context_lens,
|
input_metadata.context_lens,
|
||||||
@@ -143,11 +181,12 @@ class PagedAttention(nn.Module):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: shape = [num_tokens, num_heads * head_size]
|
query: shape = [num_tokens, num_heads * head_size]
|
||||||
key: shape = [num_tokens, num_heads * head_size]
|
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||||
value: shape = [num_tokens, num_heads * head_size]
|
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||||
key_cache: shape = [num_blocks, num_heads, head_size/x,
|
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
|
||||||
block_size, x]
|
block_size, x]
|
||||||
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
|
value_cache: shape = [num_blocks, num_kv_heads, head_size,
|
||||||
|
block_size]
|
||||||
input_metadata: metadata for paged attention.
|
input_metadata: metadata for paged attention.
|
||||||
cache_event: event to wait for the cache operations to finish.
|
cache_event: event to wait for the cache operations to finish.
|
||||||
|
|
||||||
@@ -157,8 +196,8 @@ class PagedAttention(nn.Module):
|
|||||||
|
|
||||||
# Reshape the query, key, and value tensors.
|
# Reshape the query, key, and value tensors.
|
||||||
query = query.view(-1, self.num_heads, self.head_size)
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
key = key.view(-1, self.num_heads, self.head_size)
|
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||||
value = value.view(-1, self.num_heads, self.head_size)
|
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||||
|
|
||||||
# Pre-allocate the output tensor.
|
# Pre-allocate the output tensor.
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
@@ -166,7 +205,9 @@ class PagedAttention(nn.Module):
|
|||||||
# Compute the attention op for prompts.
|
# Compute the attention op for prompts.
|
||||||
num_prompt_tokens = input_metadata.num_prompt_tokens
|
num_prompt_tokens = input_metadata.num_prompt_tokens
|
||||||
if num_prompt_tokens > 0:
|
if num_prompt_tokens > 0:
|
||||||
self.set_attn_bias(input_metadata)
|
# Prompt run.
|
||||||
|
assert input_metadata.num_generation_tokens == 0
|
||||||
|
self.set_attn_bias(input_metadata, dtype=query.dtype)
|
||||||
self.multi_query_kv_attention(
|
self.multi_query_kv_attention(
|
||||||
output[:num_prompt_tokens],
|
output[:num_prompt_tokens],
|
||||||
query[:num_prompt_tokens],
|
query[:num_prompt_tokens],
|
||||||
@@ -186,15 +227,25 @@ class PagedAttention(nn.Module):
|
|||||||
if (num_valid_tokens > 0 and key_cache is not None
|
if (num_valid_tokens > 0 and key_cache is not None
|
||||||
and value_cache is not None):
|
and value_cache is not None):
|
||||||
# The stride is 3 because the key and value are sliced from qkv.
|
# The stride is 3 because the key and value are sliced from qkv.
|
||||||
|
key_to_cache = key[:num_valid_tokens]
|
||||||
|
value_to_cache = value[:num_valid_tokens]
|
||||||
|
slot_mapping = input_metadata.slot_mapping
|
||||||
|
if input_metadata.to_cache is not None:
|
||||||
|
key_to_cache = key_to_cache[input_metadata.to_cache]
|
||||||
|
value_to_cache = value_to_cache[input_metadata.to_cache]
|
||||||
|
slot_mapping = slot_mapping[input_metadata.to_cache]
|
||||||
|
|
||||||
cache_ops.reshape_and_cache(
|
cache_ops.reshape_and_cache(
|
||||||
key[:num_valid_tokens],
|
key_to_cache,
|
||||||
value[:num_valid_tokens],
|
value_to_cache,
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
input_metadata.slot_mapping,
|
slot_mapping,
|
||||||
)
|
)
|
||||||
|
|
||||||
if input_metadata.num_generation_tokens > 0:
|
if input_metadata.num_generation_tokens > 0:
|
||||||
|
# Decoding run.
|
||||||
|
assert input_metadata.num_prompt_tokens == 0
|
||||||
assert key_cache is not None and value_cache is not None, (
|
assert key_cache is not None and value_cache is not None, (
|
||||||
"key_cache and value_cache must be provided when "
|
"key_cache and value_cache must be provided when "
|
||||||
"generating tokens.")
|
"generating tokens.")
|
||||||
@@ -210,7 +261,7 @@ class PagedAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class PagedAttentionWithRoPE(PagedAttention):
|
class PagedAttentionWithRoPE(PagedAttention):
|
||||||
"""PagedAttention with GPT-NeoX style rotary embedding."""
|
"""PagedAttention with rotary positional embedding."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -220,24 +271,33 @@ class PagedAttentionWithRoPE(PagedAttention):
|
|||||||
rotary_dim: int,
|
rotary_dim: int,
|
||||||
max_position: int = 8192,
|
max_position: int = 8192,
|
||||||
base: int = 10000,
|
base: int = 10000,
|
||||||
|
num_kv_heads: Optional[int] = None,
|
||||||
|
is_neox_style: bool = True,
|
||||||
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(num_heads, head_size, scale)
|
super().__init__(num_heads,
|
||||||
|
head_size,
|
||||||
# Create the cos and sin cache.
|
scale,
|
||||||
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
|
num_kv_heads,
|
||||||
t = torch.arange(max_position).float()
|
sliding_window=sliding_window)
|
||||||
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
|
if rope_scaling is None:
|
||||||
cos = freqs.cos()
|
self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
|
||||||
sin = freqs.sin()
|
max_position, base,
|
||||||
cache = torch.cat((cos, sin), dim=-1)
|
is_neox_style)
|
||||||
|
else:
|
||||||
# FIXME(woosuk): This assumes that we configure the default dtype when
|
scaling_type = rope_scaling["type"]
|
||||||
# initializing the model.
|
scaling_factor = rope_scaling["factor"]
|
||||||
# TODO(woosuk): Make it more robust.
|
if scaling_type == "linear":
|
||||||
torch_dtype = torch.get_default_dtype()
|
self.rotary_emb = LinearScalingRotaryEmbedding(
|
||||||
cache = cache.to(torch_dtype)
|
head_size, rotary_dim, max_position, base, is_neox_style,
|
||||||
# Embedding size: [max_position, rotary_dim]
|
scaling_factor)
|
||||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
elif scaling_type == "dynamic":
|
||||||
|
self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
|
||||||
|
head_size, rotary_dim, max_position, base, is_neox_style,
|
||||||
|
scaling_factor)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -254,12 +314,13 @@ class PagedAttentionWithRoPE(PagedAttention):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
positions: shape = [num_tokens]
|
positions: shape = [num_tokens]
|
||||||
query: shape = [num_tokens, num_heads * head_size]
|
query: shape = [num_tokens, num_heads * head_size]
|
||||||
key: shape = [num_tokens, num_heads * head_size]
|
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||||
value: shape = [num_tokens, num_heads * head_size]
|
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||||
key_cache: shape = [num_blocks, num_heads, head_size/x,
|
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
|
||||||
block_size, x]
|
block_size, x]
|
||||||
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
|
value_cache: shape = [num_blocks, num_kv_heads, head_size,
|
||||||
|
block_size]
|
||||||
input_metadata: metadata for paged attention.
|
input_metadata: metadata for paged attention.
|
||||||
cache_event: event to wait for the cache operations to finish.
|
cache_event: event to wait for the cache operations to finish.
|
||||||
|
|
||||||
@@ -269,13 +330,7 @@ class PagedAttentionWithRoPE(PagedAttention):
|
|||||||
|
|
||||||
# Apply rotary embedding to the query and key before passing them
|
# Apply rotary embedding to the query and key before passing them
|
||||||
# to the attention op.
|
# to the attention op.
|
||||||
pos_encoding_ops.rotary_embedding_neox(
|
query, key = self.rotary_emb(positions, query, key)
|
||||||
positions,
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
self.head_size,
|
|
||||||
self.cos_sin_cache,
|
|
||||||
)
|
|
||||||
return super().forward(
|
return super().forward(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
@@ -290,26 +345,31 @@ class PagedAttentionWithRoPE(PagedAttention):
|
|||||||
class PagedAttentionWithALiBi(PagedAttention):
|
class PagedAttentionWithALiBi(PagedAttention):
|
||||||
"""PagedAttention with ALiBi attention bias."""
|
"""PagedAttention with ALiBi attention bias."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self,
|
||||||
self,
|
num_heads: int,
|
||||||
num_heads: int,
|
head_size: int,
|
||||||
head_size: int,
|
scale: float,
|
||||||
scale: float,
|
slopes: List[float],
|
||||||
slopes: List[float],
|
num_kv_heads: Optional[int] = None) -> None:
|
||||||
) -> None:
|
super().__init__(num_heads, head_size, scale, num_kv_heads)
|
||||||
super().__init__(num_heads, head_size, scale)
|
|
||||||
assert len(slopes) == num_heads
|
assert len(slopes) == num_heads
|
||||||
|
|
||||||
slopes = torch.tensor(slopes, dtype=torch.float32)
|
slopes = torch.tensor(slopes, dtype=torch.float32)
|
||||||
self.register_buffer("alibi_slopes", slopes, persistent=False)
|
self.register_buffer("alibi_slopes", slopes, persistent=False)
|
||||||
|
|
||||||
def set_attn_bias(self, input_metadata: InputMetadata) -> None:
|
def set_attn_bias(self, input_metadata: InputMetadata,
|
||||||
|
dtype: torch.dtype) -> None:
|
||||||
if input_metadata.attn_bias:
|
if input_metadata.attn_bias:
|
||||||
# Already set by a previous layer.
|
# Already set by a previous layer.
|
||||||
return
|
return
|
||||||
# Generates ALiBi mask for each prompt.
|
# Generates ALiBi mask for each prompt.
|
||||||
for prompt_len in input_metadata.prompt_lens:
|
for prompt_len in input_metadata.prompt_lens:
|
||||||
bias = torch.arange(prompt_len)
|
bias = torch.arange(prompt_len, dtype=dtype)
|
||||||
|
# Note(zhuohan): HF uses
|
||||||
|
# `bias = bias[None, :].repeat(prompt_len, 1)`
|
||||||
|
# here. We find that both biases give the same results, but
|
||||||
|
# the bias below more accurately follows the original ALiBi
|
||||||
|
# paper.
|
||||||
bias = bias[None, :] - bias[:, None]
|
bias = bias[None, :] - bias[:, None]
|
||||||
bias = bias.to(self.alibi_slopes.device)
|
bias = bias.to(self.alibi_slopes.device)
|
||||||
|
|
||||||
@@ -317,11 +377,13 @@ class PagedAttentionWithALiBi(PagedAttention):
|
|||||||
# be sliced from a tensor whose length is a multiple of 8.
|
# be sliced from a tensor whose length is a multiple of 8.
|
||||||
padded_len = (prompt_len + 7) // 8 * 8
|
padded_len = (prompt_len + 7) // 8 * 8
|
||||||
bias = torch.empty(
|
bias = torch.empty(
|
||||||
|
1, # batch_size
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
padded_len,
|
prompt_len,
|
||||||
padded_len,
|
padded_len,
|
||||||
device=self.alibi_slopes.device,
|
device=self.alibi_slopes.device,
|
||||||
)[:, :prompt_len, :prompt_len].copy_(bias)
|
dtype=dtype,
|
||||||
|
)[:, :, :, :prompt_len].copy_(bias)
|
||||||
bias.mul_(self.alibi_slopes[:, None, None])
|
bias.mul_(self.alibi_slopes[:, None, None])
|
||||||
attn_bias = LowerTriangularMaskWithTensorBias(bias)
|
attn_bias = LowerTriangularMaskWithTensorBias(bias)
|
||||||
input_metadata.attn_bias.append(attn_bias)
|
input_metadata.attn_bias.append(attn_bias)
|
||||||
@@ -339,10 +401,17 @@ class PagedAttentionWithALiBi(PagedAttention):
|
|||||||
Args:
|
Args:
|
||||||
output: shape = [num_prompt_tokens, num_heads, head_size]
|
output: shape = [num_prompt_tokens, num_heads, head_size]
|
||||||
query: shape = [num_prompt_tokens, num_heads, head_size]
|
query: shape = [num_prompt_tokens, num_heads, head_size]
|
||||||
key: shape = [num_prompt_tokens, num_heads, head_size]
|
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
|
||||||
value: shape = [num_prompt_tokens, num_heads, head_size]
|
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
|
||||||
input_metadata: metadata for paged attention.
|
input_metadata: metadata for paged attention.
|
||||||
"""
|
"""
|
||||||
|
if self.num_kv_heads != self.num_heads:
|
||||||
|
# Project the key and value tensors to the desired number of heads.
|
||||||
|
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
|
||||||
|
value = torch.repeat_interleave(value,
|
||||||
|
self.num_queries_per_kv,
|
||||||
|
dim=1)
|
||||||
|
|
||||||
# FIXME(woosuk): Because xformers does not support dynamic sequence
|
# FIXME(woosuk): Because xformers does not support dynamic sequence
|
||||||
# lengths with custom attention bias, we process each prompt one by
|
# lengths with custom attention bias, we process each prompt one by
|
||||||
# one. This is inefficient, especially when we have many short prompts.
|
# one. This is inefficient, especially when we have many short prompts.
|
||||||
@@ -356,7 +425,6 @@ class PagedAttentionWithALiBi(PagedAttention):
|
|||||||
attn_bias=input_metadata.attn_bias[i],
|
attn_bias=input_metadata.attn_bias[i],
|
||||||
p=0.0,
|
p=0.0,
|
||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
op=self.attn_op,
|
|
||||||
)
|
)
|
||||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||||
output[start:end].copy_(out.squeeze(0))
|
output[start:end].copy_(out.squeeze(0))
|
||||||
@@ -376,9 +444,10 @@ class PagedAttentionWithALiBi(PagedAttention):
|
|||||||
Args:
|
Args:
|
||||||
output: shape = [num_generation_tokens, num_heads, head_size]
|
output: shape = [num_generation_tokens, num_heads, head_size]
|
||||||
query: shape = [num_generation_tokens, num_heads, head_size]
|
query: shape = [num_generation_tokens, num_heads, head_size]
|
||||||
key_cache: shape = [num_blocks, num_heads, head_size/x,
|
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
|
||||||
block_size, x]
|
block_size, x]
|
||||||
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
|
value_cache: shape = [num_blocks, num_kv_heads, head_size,
|
||||||
|
block_size]
|
||||||
input_metadata: metadata for paged attention.
|
input_metadata: metadata for paged attention.
|
||||||
"""
|
"""
|
||||||
block_size = value_cache.shape[3]
|
block_size = value_cache.shape[3]
|
||||||
@@ -387,6 +456,7 @@ class PagedAttentionWithALiBi(PagedAttention):
|
|||||||
query,
|
query,
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
|
self.head_mapping,
|
||||||
self.scale,
|
self.scale,
|
||||||
input_metadata.block_tables,
|
input_metadata.block_tables,
|
||||||
input_metadata.context_lens,
|
input_metadata.context_lens,
|
||||||
|
|||||||
37
vllm/model_executor/layers/quantized_linear/__init__.py
Normal file
37
vllm/model_executor/layers/quantized_linear/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from vllm.model_executor.layers.quantized_linear.awq import (
|
||||||
|
AWQColumnParallelLinear, AWQRowParallelLinear)
|
||||||
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
|
ColumnParallelLinear, RowParallelLinear)
|
||||||
|
|
||||||
|
_QUANTIZED_LINEAR_REGISTRY = {
|
||||||
|
"awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelLinear:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def column(cls, *args, **kwargs) -> ColumnParallelLinear:
|
||||||
|
quant_config = kwargs.get("quant_config", None)
|
||||||
|
if quant_config is None:
|
||||||
|
return ColumnParallelLinear(*args, **kwargs)
|
||||||
|
|
||||||
|
name = quant_config.get_name()
|
||||||
|
if name not in _QUANTIZED_LINEAR_REGISTRY:
|
||||||
|
raise ValueError(f"No quantized linear is found for {name}")
|
||||||
|
|
||||||
|
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0]
|
||||||
|
return quant_linear_cls(*args, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def row(cls, *args, **kwargs) -> RowParallelLinear:
|
||||||
|
quant_config = kwargs.get("quant_config", None)
|
||||||
|
if quant_config is None:
|
||||||
|
return RowParallelLinear(*args, **kwargs)
|
||||||
|
|
||||||
|
name = quant_config.get_name()
|
||||||
|
if name not in _QUANTIZED_LINEAR_REGISTRY:
|
||||||
|
raise ValueError(f"No quantized linear is found for {name}")
|
||||||
|
|
||||||
|
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
|
||||||
|
return quant_linear_cls(*args, **kwargs)
|
||||||
102
vllm/model_executor/layers/quantized_linear/awq.py
Normal file
102
vllm/model_executor/layers/quantized_linear/awq.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from vllm import quantization_ops
|
||||||
|
from vllm.model_executor.parallel_utils.tensor_parallel.layers import (
|
||||||
|
ColumnParallelLinear, RowParallelLinear)
|
||||||
|
|
||||||
|
|
||||||
|
class AWQColumnParallelLinear(ColumnParallelLinear):
|
||||||
|
|
||||||
|
def create_weights(self, dtype: torch.dtype) -> None:
|
||||||
|
assert self.input_size % self.quant_config.weight_bits == 0
|
||||||
|
assert (self.output_size_per_partition %
|
||||||
|
self.quant_config.pack_factor == 0)
|
||||||
|
self.qweight = Parameter(
|
||||||
|
torch.empty(
|
||||||
|
self.input_size,
|
||||||
|
self.output_size_per_partition //
|
||||||
|
self.quant_config.pack_factor,
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.int32,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
self.qzeros = Parameter(
|
||||||
|
torch.empty(
|
||||||
|
self.input_size // self.quant_config.group_size,
|
||||||
|
self.output_size_per_partition //
|
||||||
|
self.quant_config.pack_factor,
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.int32,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
self.scales = Parameter(
|
||||||
|
torch.empty(
|
||||||
|
self.input_size // self.quant_config.group_size,
|
||||||
|
self.output_size_per_partition,
|
||||||
|
device="cuda",
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def apply_weights(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
pack_factor = self.quant_config.pack_factor
|
||||||
|
out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
|
||||||
|
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||||
|
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
|
||||||
|
self.qzeros, pack_factor)
|
||||||
|
if bias is not None:
|
||||||
|
out = out + bias
|
||||||
|
return out.reshape(out_shape)
|
||||||
|
|
||||||
|
|
||||||
|
class AWQRowParallelLinear(RowParallelLinear):
|
||||||
|
|
||||||
|
def create_weights(self, dtype: torch.dtype) -> None:
|
||||||
|
assert (self.input_size_per_partition %
|
||||||
|
self.quant_config.weight_bits == 0)
|
||||||
|
assert self.output_size % self.quant_config.pack_factor == 0
|
||||||
|
self.qweight = Parameter(
|
||||||
|
torch.empty(
|
||||||
|
self.input_size_per_partition,
|
||||||
|
self.output_size // self.quant_config.pack_factor,
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.int32,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
self.qzeros = Parameter(
|
||||||
|
torch.empty(
|
||||||
|
self.input_size_per_partition // self.quant_config.group_size,
|
||||||
|
self.output_size // self.quant_config.pack_factor,
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.int32,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
self.scales = Parameter(
|
||||||
|
torch.empty(
|
||||||
|
self.input_size_per_partition // self.quant_config.group_size,
|
||||||
|
self.output_size,
|
||||||
|
device="cuda",
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
pack_factor = self.quant_config.pack_factor
|
||||||
|
out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
|
||||||
|
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||||
|
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
|
||||||
|
self.qzeros, pack_factor)
|
||||||
|
return out.reshape(out_shape)
|
||||||
169
vllm/model_executor/layers/rotary_embedding.py
Normal file
169
vllm/model_executor/layers/rotary_embedding.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Rotary Positional Embeddings."""
|
||||||
|
from typing import Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm import pos_encoding_ops
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryEmbedding(nn.Module):
|
||||||
|
"""Original rotary positional embedding."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
head_size: int,
|
||||||
|
rotary_dim: int,
|
||||||
|
max_position_embeddings: int,
|
||||||
|
base: int,
|
||||||
|
is_neox_style: bool,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.head_size = head_size
|
||||||
|
self.rotary_dim = rotary_dim
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.base = base
|
||||||
|
self.is_neox_style = is_neox_style
|
||||||
|
|
||||||
|
cache = self._compute_cos_sin_cache()
|
||||||
|
cache = cache.to(torch.get_default_dtype())
|
||||||
|
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||||
|
|
||||||
|
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||||
|
"""Compute the inverse frequency."""
|
||||||
|
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
|
||||||
|
# However, we use `torch.arange(..., dtype=torch.float)` instead to
|
||||||
|
# avoid numerical issues with large base values (e.g., 10000000).
|
||||||
|
# This may cause a slight numerical difference between the HF
|
||||||
|
# implementation and ours.
|
||||||
|
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
||||||
|
# use CPU to compute the cache and then move it to GPU. However, we
|
||||||
|
# create the cache on GPU for faster initialization. This may cause
|
||||||
|
# a slight numerical difference between the HF implementation and ours.
|
||||||
|
inv_freq = 1.0 / (base**(torch.arange(
|
||||||
|
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
|
||||||
|
self.rotary_dim))
|
||||||
|
return inv_freq
|
||||||
|
|
||||||
|
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||||
|
"""Compute the cos and sin cache."""
|
||||||
|
inv_freq = self._compute_inv_freq(self.base)
|
||||||
|
t = torch.arange(self.max_position_embeddings,
|
||||||
|
dtype=torch.float,
|
||||||
|
device="cuda")
|
||||||
|
|
||||||
|
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||||
|
cos = freqs.cos()
|
||||||
|
sin = freqs.sin()
|
||||||
|
cache = torch.cat((cos, sin), dim=-1)
|
||||||
|
return cache
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# pos_encoding_ops.rotary_embedding() is an in-place operation that
|
||||||
|
# updates the query and key tensors.
|
||||||
|
pos_encoding_ops.rotary_embedding(positions, query, key,
|
||||||
|
self.head_size, self.cos_sin_cache,
|
||||||
|
self.is_neox_style)
|
||||||
|
return query, key
|
||||||
|
|
||||||
|
|
||||||
|
class LinearScalingRotaryEmbedding(RotaryEmbedding):
|
||||||
|
"""RotaryEmbedding extended with linear scaling.
|
||||||
|
|
||||||
|
Credits to the Reddit user /u/kaiokendev
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
head_size: int,
|
||||||
|
rotary_dim: int,
|
||||||
|
max_position_embeddings: int,
|
||||||
|
base: int,
|
||||||
|
is_neox_style: bool,
|
||||||
|
scaling_factor: float,
|
||||||
|
) -> None:
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||||
|
is_neox_style)
|
||||||
|
|
||||||
|
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||||
|
inv_freq = self._compute_inv_freq(self.base)
|
||||||
|
# NOTE(woosuk): self.max_position_embeddings is the original
|
||||||
|
# maximum length before applying the rope scaling.
|
||||||
|
# Thus, the maximum length after applying the rope scaling is
|
||||||
|
# self.max_position_embeddings * self.scaling_factor.
|
||||||
|
max_len = self.max_position_embeddings * self.scaling_factor
|
||||||
|
t = torch.arange(max_len, dtype=torch.float, device="cuda")
|
||||||
|
t = t / self.scaling_factor
|
||||||
|
|
||||||
|
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||||
|
cos = freqs.cos()
|
||||||
|
sin = freqs.sin()
|
||||||
|
cache = torch.cat((cos, sin), dim=-1)
|
||||||
|
return cache
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
|
||||||
|
"""RotaryEmbedding extended with Dynamic NTK scaling.
|
||||||
|
|
||||||
|
Credits to the Reddit users /u/bloc97 and /u/emozilla
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
head_size: int,
|
||||||
|
rotary_dim: int,
|
||||||
|
max_position_embeddings: int,
|
||||||
|
base: int,
|
||||||
|
is_neox_style: bool,
|
||||||
|
scaling_factor: float,
|
||||||
|
) -> None:
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||||
|
is_neox_style)
|
||||||
|
|
||||||
|
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||||
|
# NOTE(woosuk): self.max_position_embeddings is the original
|
||||||
|
# maximum length before applying the rope scaling.
|
||||||
|
# Thus, the maximum length after applying the rope scaling is
|
||||||
|
# self.max_position_embeddings * self.scaling_factor.
|
||||||
|
max_len = self.max_position_embeddings * self.scaling_factor
|
||||||
|
base = self.base * (
|
||||||
|
(self.scaling_factor * max_len / self.max_position_embeddings) -
|
||||||
|
(self.scaling_factor - 1))**(self.rotary_dim /
|
||||||
|
(self.rotary_dim - 2))
|
||||||
|
inv_freq = self._compute_inv_freq(base)
|
||||||
|
t = torch.arange(max_len, dtype=torch.float, device="cuda")
|
||||||
|
|
||||||
|
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||||
|
cos = freqs.cos()
|
||||||
|
sin = freqs.sin()
|
||||||
|
cache = torch.cat((cos, sin), dim=-1)
|
||||||
|
return cache
|
||||||
@@ -1,15 +1,14 @@
|
|||||||
"""A layer that samples the next tokens from the model's outputs."""
|
"""A layer that samples the next tokens from the model's outputs."""
|
||||||
from typing import Dict, List, Tuple, Optional
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.model_executor.input_metadata import InputMetadata
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
gather_from_tensor_model_parallel_region)
|
gather_from_tensor_model_parallel_region)
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs
|
||||||
|
|
||||||
_SAMPLING_EPS = 1e-5
|
_SAMPLING_EPS = 1e-5
|
||||||
|
|
||||||
@@ -38,15 +37,14 @@ class Sampler(nn.Module):
|
|||||||
embedding: torch.Tensor,
|
embedding: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> Dict[int, SequenceOutputs]:
|
embedding_bias: Optional[torch.Tensor] = None,
|
||||||
|
) -> SamplerOutput:
|
||||||
# Get the hidden states that we use for sampling.
|
# Get the hidden states that we use for sampling.
|
||||||
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
|
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
|
||||||
|
|
||||||
# Get the logits for the next tokens.
|
# Get the logits for the next tokens.
|
||||||
logits = torch.matmul(hidden_states, embedding.t())
|
logits = _get_logits(hidden_states, embedding, embedding_bias,
|
||||||
logits = gather_from_tensor_model_parallel_region(logits)
|
self.vocab_size)
|
||||||
# Remove paddings in vocab (if any).
|
|
||||||
logits = logits[:, :self.vocab_size]
|
|
||||||
|
|
||||||
# Apply presence and frequency penalties.
|
# Apply presence and frequency penalties.
|
||||||
output_tokens = _get_output_tokens(input_metadata)
|
output_tokens = _get_output_tokens(input_metadata)
|
||||||
@@ -56,7 +54,7 @@ class Sampler(nn.Module):
|
|||||||
assert len(presence_penalties) == logits.shape[0]
|
assert len(presence_penalties) == logits.shape[0]
|
||||||
assert len(frequency_penalties) == logits.shape[0]
|
assert len(frequency_penalties) == logits.shape[0]
|
||||||
logits = _apply_penalties(logits, output_tokens, presence_penalties,
|
logits = _apply_penalties(logits, output_tokens, presence_penalties,
|
||||||
frequency_penalties, self.vocab_size)
|
frequency_penalties)
|
||||||
|
|
||||||
# Apply temperature scaling.
|
# Apply temperature scaling.
|
||||||
temperatures = _get_temperatures(input_metadata)
|
temperatures = _get_temperatures(input_metadata)
|
||||||
@@ -68,36 +66,66 @@ class Sampler(nn.Module):
|
|||||||
# Use in-place division to avoid creating a new tensor.
|
# Use in-place division to avoid creating a new tensor.
|
||||||
logits.div_(t.unsqueeze(dim=1))
|
logits.div_(t.unsqueeze(dim=1))
|
||||||
|
|
||||||
# We use float32 for probabilities and log probabilities.
|
|
||||||
# Compute the probabilities.
|
|
||||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
|
||||||
# Compute the log probabilities (before applying top-p and top-k).
|
|
||||||
logprobs = torch.log(probs)
|
|
||||||
|
|
||||||
# Apply top-p and top-k truncation.
|
# Apply top-p and top-k truncation.
|
||||||
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
|
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
|
||||||
assert len(top_ps) == len(top_ks) == probs.shape[0]
|
assert len(top_ps) == len(top_ks) == logits.shape[0]
|
||||||
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
|
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
|
||||||
do_top_k = any(k != self.vocab_size for k in top_ks)
|
do_top_k = any(k != self.vocab_size for k in top_ks)
|
||||||
if do_top_p or do_top_k:
|
if do_top_p or do_top_k:
|
||||||
probs = _apply_top_p_top_k(probs, top_ps, top_ks)
|
logits = _apply_top_p_top_k(logits, top_ps, top_ks)
|
||||||
|
|
||||||
|
# We use float32 for probabilities and log probabilities.
|
||||||
|
# Compute the probabilities.
|
||||||
|
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||||
|
# Compute the log probabilities.
|
||||||
|
# Use log_softmax to ensure numerical stability.
|
||||||
|
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
||||||
|
|
||||||
# Sample the next tokens.
|
# Sample the next tokens.
|
||||||
return _sample(probs, logprobs, input_metadata)
|
return _sample(probs, logprobs, input_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
|
||||||
|
embedding_bias: Optional[torch.Tensor],
|
||||||
|
vocab_size: int) -> torch.Tensor:
|
||||||
|
# Get the logits for the next tokens.
|
||||||
|
logits = torch.matmul(hidden_states, embedding.t())
|
||||||
|
if embedding_bias is not None:
|
||||||
|
logits += embedding_bias
|
||||||
|
logits = gather_from_tensor_model_parallel_region(logits)
|
||||||
|
# Remove paddings in vocab (if any).
|
||||||
|
logits = logits[:, :vocab_size]
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def _prune_hidden_states(
|
def _prune_hidden_states(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
last_token_indices = {t: [] for t in SamplingType}
|
||||||
start_idx = 0
|
start_idx = 0
|
||||||
last_token_indicies: List[int] = []
|
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||||
for prompt_len in input_metadata.prompt_lens:
|
seq_ids, sampling_params = seq_group
|
||||||
last_token_indicies.append(start_idx + prompt_len - 1)
|
sampling_type = sampling_params.sampling_type
|
||||||
start_idx += prompt_len
|
if i < input_metadata.num_prompts:
|
||||||
last_token_indicies.extend(
|
assert len(seq_ids) == 1, "Prompt input should have only one seq."
|
||||||
range(start_idx, start_idx + input_metadata.num_generation_tokens))
|
prompt_len = input_metadata.prompt_lens[i]
|
||||||
return hidden_states[last_token_indicies]
|
last_token_indices[sampling_type].append(start_idx + prompt_len -
|
||||||
|
1)
|
||||||
|
start_idx += prompt_len
|
||||||
|
else:
|
||||||
|
num_seqs = len(seq_ids)
|
||||||
|
last_token_indices[sampling_type].extend(
|
||||||
|
range(start_idx, start_idx + num_seqs))
|
||||||
|
start_idx += num_seqs
|
||||||
|
|
||||||
|
all_last_token_indices = []
|
||||||
|
for sampling_type in SamplingType:
|
||||||
|
all_last_token_indices.extend(last_token_indices[sampling_type])
|
||||||
|
all_last_token_indices = torch.tensor(all_last_token_indices,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=hidden_states.device)
|
||||||
|
return hidden_states.index_select(0, all_last_token_indices)
|
||||||
|
|
||||||
|
|
||||||
def _get_penalties(
|
def _get_penalties(
|
||||||
@@ -105,37 +133,22 @@ def _get_penalties(
|
|||||||
# Collect the presence and frequency penalties.
|
# Collect the presence and frequency penalties.
|
||||||
presence_penalties: List[float] = []
|
presence_penalties: List[float] = []
|
||||||
frequency_penalties: List[float] = []
|
frequency_penalties: List[float] = []
|
||||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
for seq_group in input_metadata.seq_groups:
|
||||||
seq_ids, sampling_params = seq_group
|
seq_ids, sampling_params = seq_group
|
||||||
p = sampling_params.presence_penalty
|
p = sampling_params.presence_penalty
|
||||||
f = sampling_params.frequency_penalty
|
f = sampling_params.frequency_penalty
|
||||||
if i < input_metadata.num_prompts:
|
presence_penalties += [p] * len(seq_ids)
|
||||||
# A prompt input.
|
frequency_penalties += [f] * len(seq_ids)
|
||||||
presence_penalties.append(p)
|
|
||||||
frequency_penalties.append(f)
|
|
||||||
else:
|
|
||||||
# A generation token.
|
|
||||||
presence_penalties += [p] * len(seq_ids)
|
|
||||||
frequency_penalties += [f] * len(seq_ids)
|
|
||||||
return presence_penalties, frequency_penalties
|
return presence_penalties, frequency_penalties
|
||||||
|
|
||||||
|
|
||||||
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
|
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
|
||||||
output_tokens: List[List[int]] = []
|
output_tokens: List[List[int]] = []
|
||||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
for seq_group in input_metadata.seq_groups:
|
||||||
seq_ids, _ = seq_group
|
seq_ids, _ = seq_group
|
||||||
if i < input_metadata.num_prompts:
|
for seq_id in seq_ids:
|
||||||
# A prompt input.
|
|
||||||
# NOTE: While the prompt input usually has no output tokens,
|
|
||||||
# it may have output tokens in the case of recomputation.
|
|
||||||
seq_id = seq_ids[0]
|
|
||||||
seq_data = input_metadata.seq_data[seq_id]
|
seq_data = input_metadata.seq_data[seq_id]
|
||||||
output_tokens.append(seq_data.output_token_ids)
|
output_tokens.append(seq_data.output_token_ids)
|
||||||
else:
|
|
||||||
# A generation token.
|
|
||||||
for seq_id in seq_ids:
|
|
||||||
seq_data = input_metadata.seq_data[seq_id]
|
|
||||||
output_tokens.append(seq_data.output_token_ids)
|
|
||||||
return output_tokens
|
return output_tokens
|
||||||
|
|
||||||
|
|
||||||
@@ -144,52 +157,56 @@ def _apply_penalties(
|
|||||||
output_tokens: List[List[int]],
|
output_tokens: List[List[int]],
|
||||||
presence_penalties: List[float],
|
presence_penalties: List[float],
|
||||||
frequency_penalties: List[float],
|
frequency_penalties: List[float],
|
||||||
vocab_size: int,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
num_seqs = logits.shape[0]
|
num_seqs, vocab_size = logits.shape
|
||||||
# Collect the indices of sequences that have non-zero penalties.
|
|
||||||
indices = []
|
|
||||||
for i in range(num_seqs):
|
for i in range(num_seqs):
|
||||||
if not output_tokens[i]:
|
if not output_tokens[i]:
|
||||||
continue
|
continue
|
||||||
p = presence_penalties[i]
|
p = presence_penalties[i]
|
||||||
f = frequency_penalties[i]
|
f = frequency_penalties[i]
|
||||||
if p < _SAMPLING_EPS and f < _SAMPLING_EPS:
|
if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
|
||||||
continue
|
continue
|
||||||
indices.append(i)
|
break
|
||||||
|
else:
|
||||||
# Return early if all sequences have zero penalties.
|
# Return early if all sequences have zero penalties.
|
||||||
if not indices:
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
bin_counts = []
|
max_output_len = max(len(tokens) for tokens in output_tokens)
|
||||||
for i in indices:
|
padded_output_tokens = [
|
||||||
bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
|
tokens + [vocab_size] * (max_output_len - len(tokens))
|
||||||
bin_counts = np.stack(bin_counts, axis=0)
|
for tokens in output_tokens
|
||||||
bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
|
]
|
||||||
device=logits.device)
|
output_tokens_tensor = torch.tensor(padded_output_tokens,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=logits.device)
|
||||||
|
|
||||||
|
# Compute the bin counts for the output tokens.
|
||||||
|
# vocab_size + 1 for padding.
|
||||||
|
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
|
||||||
|
dtype=torch.long,
|
||||||
|
device=logits.device)
|
||||||
|
bin_counts.scatter_add_(1, output_tokens_tensor,
|
||||||
|
torch.ones_like(output_tokens_tensor))
|
||||||
|
bin_counts = bin_counts[:, :vocab_size] # Remove the padding bin.
|
||||||
|
|
||||||
frequency_penalties = [frequency_penalties[i] for i in indices]
|
|
||||||
frequency_penalties = torch.tensor(frequency_penalties,
|
frequency_penalties = torch.tensor(frequency_penalties,
|
||||||
dtype=logits.dtype,
|
dtype=logits.dtype,
|
||||||
device=logits.device)
|
device=logits.device)
|
||||||
presence_penalties = [presence_penalties[i] for i in indices]
|
|
||||||
presence_penalties = torch.tensor(presence_penalties,
|
presence_penalties = torch.tensor(presence_penalties,
|
||||||
dtype=logits.dtype,
|
dtype=logits.dtype,
|
||||||
device=logits.device)
|
device=logits.device)
|
||||||
|
|
||||||
# We follow the definition in OpenAI API.
|
# We follow the definition in OpenAI API.
|
||||||
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
|
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
|
||||||
logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
|
logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
|
||||||
presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
|
logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
|
||||||
logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
|
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
|
||||||
# Collect the temperatures for the logits.
|
# Collect the temperatures for the logits.
|
||||||
temperatures: List[float] = []
|
temperatures: List[float] = []
|
||||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
for seq_group in input_metadata.seq_groups:
|
||||||
seq_ids, sampling_params = seq_group
|
seq_ids, sampling_params = seq_group
|
||||||
temperature = sampling_params.temperature
|
temperature = sampling_params.temperature
|
||||||
if temperature < _SAMPLING_EPS:
|
if temperature < _SAMPLING_EPS:
|
||||||
@@ -197,13 +214,7 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
|
|||||||
# (i.e., greedy sampling or beam search).
|
# (i.e., greedy sampling or beam search).
|
||||||
# Set the temperature to 1 to avoid division by zero.
|
# Set the temperature to 1 to avoid division by zero.
|
||||||
temperature = 1.0
|
temperature = 1.0
|
||||||
|
temperatures += [temperature] * len(seq_ids)
|
||||||
if i < input_metadata.num_prompts:
|
|
||||||
# A prompt input.
|
|
||||||
temperatures.append(temperature)
|
|
||||||
else:
|
|
||||||
# A generation token.
|
|
||||||
temperatures += [temperature] * len(seq_ids)
|
|
||||||
return temperatures
|
return temperatures
|
||||||
|
|
||||||
|
|
||||||
@@ -213,221 +224,279 @@ def _get_top_p_top_k(
|
|||||||
) -> Tuple[List[float], List[int]]:
|
) -> Tuple[List[float], List[int]]:
|
||||||
top_ps: List[float] = []
|
top_ps: List[float] = []
|
||||||
top_ks: List[int] = []
|
top_ks: List[int] = []
|
||||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
for seq_group in input_metadata.seq_groups:
|
||||||
seq_ids, sampling_params = seq_group
|
seq_ids, sampling_params = seq_group
|
||||||
top_p = sampling_params.top_p
|
top_p = sampling_params.top_p
|
||||||
# k should not be greater than the vocab size.
|
# k should not be greater than the vocab size.
|
||||||
top_k = min(sampling_params.top_k, vocab_size)
|
top_k = min(sampling_params.top_k, vocab_size)
|
||||||
# k=-1 means no truncation.
|
# k=-1 means no truncation.
|
||||||
top_k = vocab_size if top_k == -1 else top_k
|
top_k = vocab_size if top_k == -1 else top_k
|
||||||
if i < input_metadata.num_prompts:
|
top_ps += [top_p] * len(seq_ids)
|
||||||
# A prompt input.
|
top_ks += [top_k] * len(seq_ids)
|
||||||
top_ps.append(top_p)
|
|
||||||
top_ks.append(top_k)
|
|
||||||
else:
|
|
||||||
# A generation token.
|
|
||||||
top_ps += [top_p] * len(seq_ids)
|
|
||||||
top_ks += [top_k] * len(seq_ids)
|
|
||||||
return top_ps, top_ks
|
return top_ps, top_ks
|
||||||
|
|
||||||
|
|
||||||
def _apply_top_p_top_k(
|
def _apply_top_p_top_k(
|
||||||
probs: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
top_ps: List[float],
|
top_ps: List[float],
|
||||||
top_ks: List[int],
|
top_ks: List[int],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
|
p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
|
||||||
k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
|
k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
|
||||||
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
|
||||||
|
|
||||||
# Apply top-p.
|
# Apply top-p.
|
||||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
probs_sort = logits_sort.softmax(dim=-1)
|
||||||
|
probs_sum = probs_sort.cumsum(dim=-1)
|
||||||
top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
|
top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
|
||||||
probs_sort[top_p_mask] = 0.0
|
logits_sort[top_p_mask] = -float("inf")
|
||||||
|
|
||||||
# Apply top-k.
|
# Apply top-k.
|
||||||
# Create a mask for the top-k elements.
|
# Create a mask for the top-k elements.
|
||||||
top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
|
top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
|
||||||
top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
|
top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
|
||||||
top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
|
top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
|
||||||
probs_sort[top_k_mask] = 0.0
|
logits_sort[top_k_mask] = -float("inf")
|
||||||
|
|
||||||
# Re-sort the probabilities.
|
# Re-sort the probabilities.
|
||||||
probs = torch.gather(probs_sort,
|
logits = torch.gather(logits_sort,
|
||||||
dim=-1,
|
dim=-1,
|
||||||
index=torch.argsort(probs_idx, dim=-1))
|
index=torch.argsort(logits_idx, dim=-1))
|
||||||
return probs
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def _get_topk_logprobs(
|
def _get_topk_logprobs(
|
||||||
logprobs: torch.Tensor,
|
logprobs: torch.Tensor,
|
||||||
num_logprobs: Optional[int],
|
num_logprobs: Optional[int],
|
||||||
) -> Dict[int, float]:
|
) -> List[Dict[int, float]]:
|
||||||
|
num_seqs = logprobs.size(0)
|
||||||
if num_logprobs is None or num_logprobs == 0:
|
if num_logprobs is None or num_logprobs == 0:
|
||||||
return {}
|
return [{} for _ in range(num_seqs)]
|
||||||
|
|
||||||
topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
|
all_topk_logprobs, all_topk_ids = torch.topk(logprobs,
|
||||||
if num_logprobs == 1:
|
num_logprobs,
|
||||||
topk_logprobs = [topk_logprobs.item()]
|
dim=-1)
|
||||||
topk_ids = [topk_ids.item()]
|
all_topk_logprobs = all_topk_logprobs.cpu()
|
||||||
else:
|
all_topk_ids = all_topk_ids.cpu()
|
||||||
topk_logprobs = topk_logprobs.tolist()
|
all_token_to_logprob = []
|
||||||
topk_ids = topk_ids.tolist()
|
for topk_logprobs, topk_ids in zip(all_topk_logprobs, all_topk_ids):
|
||||||
|
token_to_logprob: Dict[int, float] = {}
|
||||||
token_to_logprob: Dict[int, float] = {}
|
for token_id, logprob in zip(topk_ids, topk_logprobs):
|
||||||
for token_id, logprob in zip(topk_ids, topk_logprobs):
|
token_to_logprob[token_id.item()] = logprob.item()
|
||||||
token_to_logprob[token_id] = logprob
|
all_token_to_logprob.append(token_to_logprob)
|
||||||
return token_to_logprob
|
return all_token_to_logprob
|
||||||
|
|
||||||
|
|
||||||
def _sample_from_prompt(
|
def _build_sequence_outputs(
|
||||||
prob: torch.Tensor,
|
parent_ids: List[int],
|
||||||
sampling_params: SamplingParams,
|
next_token_ids: List[int],
|
||||||
) -> List[int]:
|
selected_token_logprobs: torch.Tensor,
|
||||||
if sampling_params.use_beam_search:
|
parent_seq_ids: List[int],
|
||||||
# Beam search.
|
parent_logprobs: torch.Tensor,
|
||||||
beam_width = sampling_params.best_of
|
num_output_logprobs: Optional[int],
|
||||||
_, next_token_ids = torch.topk(prob, beam_width)
|
) -> List[SequenceOutputs]:
|
||||||
next_token_ids = next_token_ids.tolist()
|
# Get top-k log probabilities for the next tokens.
|
||||||
elif sampling_params.temperature < _SAMPLING_EPS:
|
next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs)
|
||||||
# Greedy sampling.
|
seq_outputs: List[SequenceOutputs] = []
|
||||||
assert sampling_params.best_of == 1
|
for parent_id, next_token_id, token_logprob in zip(
|
||||||
next_token_id = torch.argmax(prob)
|
parent_ids, next_token_ids, selected_token_logprobs):
|
||||||
next_token_ids = [next_token_id.item()]
|
output_logprobs = next_logprobs[parent_id].copy()
|
||||||
else:
|
output_logprobs[next_token_id] = token_logprob
|
||||||
# Random sampling.
|
seq_outputs.append(
|
||||||
# Sample `best_of` tokens for the prompt.
|
SequenceOutputs(parent_seq_ids[parent_id], next_token_id,
|
||||||
num_seqs = sampling_params.best_of
|
output_logprobs))
|
||||||
next_token_ids = torch.multinomial(prob,
|
return seq_outputs
|
||||||
num_samples=num_seqs,
|
|
||||||
replacement=True)
|
|
||||||
next_token_ids = next_token_ids.tolist()
|
|
||||||
return next_token_ids
|
|
||||||
|
|
||||||
|
|
||||||
def _sample_from_generation_tokens(
|
def _greedy_sample(
|
||||||
seq_ids: List[int],
|
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||||
probs: torch.Tensor,
|
|
||||||
logprobs: torch.Tensor,
|
logprobs: torch.Tensor,
|
||||||
seq_logprobs: List[float],
|
) -> List[Tuple[List[int], List[int]]]:
|
||||||
sampling_params: SamplingParams,
|
samples = torch.argmax(logprobs, dim=-1).cpu()
|
||||||
) -> Tuple[List[int], List[int]]:
|
sample_idx = 0
|
||||||
# NOTE(woosuk): sampling_params.best_of can be greater than
|
results = []
|
||||||
# len(seq_ids) because some sequences in the group might have
|
for seq_group in selected_seq_groups:
|
||||||
# been already terminated.
|
seq_ids, _ = seq_group
|
||||||
if sampling_params.use_beam_search:
|
num_parent_seqs = len(seq_ids)
|
||||||
# Beam search.
|
assert num_parent_seqs == 1, (
|
||||||
# Add cumulative logprobs for the sequences in the group.
|
"Greedy sampling should have only one seq.")
|
||||||
seq_logprobs = torch.tensor(seq_logprobs,
|
parent_ids = list(range(num_parent_seqs))
|
||||||
dtype=torch.float,
|
next_token_ids = [samples[sample_idx].item()]
|
||||||
device=logprobs.device)
|
results.append((next_token_ids, parent_ids))
|
||||||
logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
|
sample_idx += num_parent_seqs
|
||||||
|
assert sample_idx == logprobs.size(0)
|
||||||
|
return results
|
||||||
|
|
||||||
vocab_size = logprobs.size(-1)
|
|
||||||
beam_width = len(seq_ids)
|
|
||||||
_, topk_ids = torch.topk(logprobs.flatten(), beam_width)
|
|
||||||
topk_ids = topk_ids.tolist()
|
|
||||||
seq_idx = [i // vocab_size for i in topk_ids]
|
|
||||||
beam_seq_ids = [seq_ids[i] for i in seq_idx]
|
|
||||||
token_ids = [i % vocab_size for i in topk_ids]
|
|
||||||
|
|
||||||
beam_outputs: Dict[int, Tuple[int, int]] = {}
|
def _random_sample(
|
||||||
outstanding_beams: List[Tuple[int, int]] = []
|
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||||
# If a beam survives, continue with it.
|
is_prompts: List[bool],
|
||||||
for seq_id, token_id in zip(beam_seq_ids, token_ids):
|
probs: torch.Tensor,
|
||||||
if seq_id not in beam_outputs:
|
) -> List[Tuple[List[int], List[int]]]:
|
||||||
beam_outputs[seq_id] = (seq_id, token_id)
|
# Find the maximum best_of value of the prompt phase requests.
|
||||||
else:
|
max_best_of = 1
|
||||||
outstanding_beams.append((seq_id, token_id))
|
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
||||||
|
if is_prompt:
|
||||||
|
seq_ids, sampling_params = seq_group
|
||||||
|
max_best_of = max(max_best_of, sampling_params.best_of)
|
||||||
|
random_samples = torch.multinomial(probs,
|
||||||
|
num_samples=max_best_of,
|
||||||
|
replacement=True).cpu()
|
||||||
|
sample_idx = 0
|
||||||
|
results = []
|
||||||
|
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
||||||
|
seq_ids, sampling_params = seq_group
|
||||||
|
num_parent_seqs = len(seq_ids)
|
||||||
|
if is_prompt:
|
||||||
|
# Prompt phase.
|
||||||
|
assert num_parent_seqs == 1, (
|
||||||
|
"Prompt input should have only one seq.")
|
||||||
|
parent_ids = [0] * sampling_params.best_of
|
||||||
|
next_token_ids = random_samples[
|
||||||
|
sample_idx, :sampling_params.best_of].tolist()
|
||||||
|
else:
|
||||||
|
# Generation phase.
|
||||||
|
parent_ids = list(range(num_parent_seqs))
|
||||||
|
next_token_ids = random_samples[sample_idx:sample_idx +
|
||||||
|
num_parent_seqs, 0].tolist()
|
||||||
|
results.append((next_token_ids, parent_ids))
|
||||||
|
sample_idx += num_parent_seqs
|
||||||
|
assert sample_idx == probs.size(0)
|
||||||
|
return results
|
||||||
|
|
||||||
# If a beam is discarded, fork another beam.
|
|
||||||
for seq_id in seq_ids:
|
|
||||||
if seq_id not in beam_outputs:
|
|
||||||
beam_outputs[seq_id] = outstanding_beams.pop()
|
|
||||||
assert not outstanding_beams
|
|
||||||
|
|
||||||
parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
|
def _beam_search_sample(
|
||||||
next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
|
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||||
elif sampling_params.temperature < _SAMPLING_EPS:
|
is_prompts: List[bool],
|
||||||
# Greedy sampling.
|
seq_data: Dict[int, SequenceData],
|
||||||
assert len(seq_ids) == 1
|
logprobs: torch.Tensor,
|
||||||
next_token_id = torch.argmax(probs, dim=-1)
|
) -> List[Tuple[List[int], List[int]]]:
|
||||||
next_token_ids = [int(next_token_id.item())]
|
# We sample 2 * beam_width candidates to make sure that with high
|
||||||
parent_seq_ids = seq_ids
|
# probability we can get `beam_width` candidates in addition to
|
||||||
else:
|
# the finished sequences for the next iteration. See
|
||||||
# Random sampling.
|
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
|
||||||
# Sample 1 token for each sequence in the group.
|
# for details. See also HF reference:
|
||||||
next_token_ids = torch.multinomial(probs,
|
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
|
||||||
num_samples=1,
|
#
|
||||||
replacement=True)
|
# Note: Beam search is not vectorized, so its speed can be slower than
|
||||||
next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
|
# other sampling methods.
|
||||||
parent_seq_ids = seq_ids
|
sample_idx = 0
|
||||||
return parent_seq_ids, next_token_ids
|
results = []
|
||||||
|
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
||||||
|
seq_ids, sampling_params = seq_group
|
||||||
|
num_parent_seqs = len(seq_ids)
|
||||||
|
beam_width = sampling_params.best_of
|
||||||
|
seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
|
||||||
|
if is_prompt:
|
||||||
|
# Prompt phase.
|
||||||
|
assert num_parent_seqs == 1, (
|
||||||
|
"Prompt input should have only one seq.")
|
||||||
|
parent_ids = [0] * (2 * beam_width)
|
||||||
|
_, next_token_ids = torch.topk(seq_group_logprobs[0],
|
||||||
|
2 * beam_width)
|
||||||
|
next_token_ids = next_token_ids.tolist()
|
||||||
|
else:
|
||||||
|
# Generation phase.
|
||||||
|
cumulative_logprobs = [
|
||||||
|
seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
|
||||||
|
]
|
||||||
|
cumulative_logprobs = torch.tensor(
|
||||||
|
cumulative_logprobs,
|
||||||
|
dtype=torch.float,
|
||||||
|
device=seq_group_logprobs.device)
|
||||||
|
seq_group_logprobs = (seq_group_logprobs +
|
||||||
|
cumulative_logprobs.unsqueeze(dim=1))
|
||||||
|
_, topk_ids = torch.topk(seq_group_logprobs.flatten(),
|
||||||
|
2 * beam_width)
|
||||||
|
topk_ids = topk_ids.tolist()
|
||||||
|
vocab_size = seq_group_logprobs.size(-1)
|
||||||
|
parent_ids = [i // vocab_size for i in topk_ids]
|
||||||
|
next_token_ids = [i % vocab_size for i in topk_ids]
|
||||||
|
results.append((next_token_ids, parent_ids))
|
||||||
|
sample_idx += num_parent_seqs
|
||||||
|
assert sample_idx == logprobs.size(0)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
def _sample(
|
def _sample(
|
||||||
probs: torch.Tensor,
|
probs: torch.Tensor,
|
||||||
logprobs: torch.Tensor,
|
logprobs: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
seq_outputs: Dict[int, SequenceOutputs] = {}
|
categorized_seq_group_ids = {t: [] for t in SamplingType}
|
||||||
|
category_num_tokens = {t: 0 for t in SamplingType}
|
||||||
# TODO(woosuk): Optimize.
|
|
||||||
idx = 0
|
|
||||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||||
seq_ids, sampling_params = seq_group
|
seq_ids, sampling_params = seq_group
|
||||||
if i < input_metadata.num_prompts:
|
sampling_type = sampling_params.sampling_type
|
||||||
# Generate the next tokens for a prompt input.
|
categorized_seq_group_ids[sampling_type].append(i)
|
||||||
assert len(seq_ids) == sampling_params.best_of
|
num_seqs = len(seq_ids)
|
||||||
prob = probs[idx]
|
category_num_tokens[sampling_type] += num_seqs
|
||||||
logprob = logprobs[idx]
|
|
||||||
idx += 1
|
|
||||||
|
|
||||||
# Sample the next tokens.
|
seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
|
||||||
next_token_ids = _sample_from_prompt(prob, sampling_params)
|
category_start_idx = 0
|
||||||
# Get top-k log probabilities for the next tokens.
|
for sampling_type in SamplingType:
|
||||||
next_logprobs = _get_topk_logprobs(logprob,
|
seq_group_ids = categorized_seq_group_ids[sampling_type]
|
||||||
sampling_params.logprobs)
|
seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
|
||||||
|
is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
|
||||||
# Build the output.
|
num_tokens = category_num_tokens[sampling_type]
|
||||||
for seq_id, next_token_id in zip(seq_ids, next_token_ids):
|
if num_tokens == 0:
|
||||||
output_logprobs = next_logprobs.copy()
|
continue
|
||||||
output_logprobs[next_token_id] = logprob[next_token_id].item()
|
category_logprobs = logprobs[category_start_idx:category_start_idx +
|
||||||
seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id,
|
num_tokens]
|
||||||
next_token_id,
|
category_probs = probs[category_start_idx:category_start_idx +
|
||||||
output_logprobs)
|
num_tokens]
|
||||||
|
if sampling_type == SamplingType.GREEDY:
|
||||||
|
sample_results = _greedy_sample(seq_groups, category_logprobs)
|
||||||
|
elif sampling_type == SamplingType.RANDOM:
|
||||||
|
sample_results = _random_sample(seq_groups, is_prompts,
|
||||||
|
category_probs)
|
||||||
|
elif sampling_type == SamplingType.BEAM:
|
||||||
|
sample_results = _beam_search_sample(seq_groups, is_prompts,
|
||||||
|
input_metadata.seq_data,
|
||||||
|
category_logprobs)
|
||||||
else:
|
else:
|
||||||
# Generate the next tokens for generation tokens.
|
raise ValueError(f"Unsupported sampling type: {sampling_type}")
|
||||||
prob = probs[idx:idx + len(seq_ids)]
|
|
||||||
logprob = logprobs[idx:idx + len(seq_ids)]
|
|
||||||
idx += len(seq_ids)
|
|
||||||
|
|
||||||
# Sample the next tokens.
|
# Batched query for logprobs of selected token
|
||||||
seq_logprobs = [
|
batched_logprobs_query_seq_indices: List[int] = []
|
||||||
input_metadata.seq_data[seq_id].cumulative_logprob
|
batched_logprobs_query_token_indices: List[int] = []
|
||||||
for seq_id in seq_ids
|
sample_idx = 0
|
||||||
]
|
for seq_group_id, seq_group, sample_result in zip(
|
||||||
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
|
seq_group_ids, seq_groups, sample_results):
|
||||||
seq_ids, prob, logprob, seq_logprobs, sampling_params)
|
seq_ids, sampling_params = seq_group
|
||||||
|
next_token_ids, parent_ids = sample_result
|
||||||
|
num_parent_seqs = len(seq_ids)
|
||||||
|
batched_logprobs_query_seq_indices.extend(
|
||||||
|
[sample_idx + parent_id for parent_id in parent_ids])
|
||||||
|
batched_logprobs_query_token_indices.extend(next_token_ids)
|
||||||
|
sample_idx += num_parent_seqs
|
||||||
|
assert sample_idx == num_tokens
|
||||||
|
batched_logprobs_query_result = category_logprobs[[
|
||||||
|
batched_logprobs_query_seq_indices,
|
||||||
|
batched_logprobs_query_token_indices
|
||||||
|
]].tolist()
|
||||||
|
|
||||||
# Get top-k log probabilities for the next tokens.
|
# Build the sequence outputs.
|
||||||
next_logprobs: Dict[int, Dict[int, float]] = {}
|
sample_idx = 0
|
||||||
for j, seq_id in enumerate(seq_ids):
|
result_idx = 0
|
||||||
next_logprobs[seq_id] = _get_topk_logprobs(
|
for seq_group_id, seq_group, sample_result in zip(
|
||||||
logprob[j], sampling_params.logprobs)
|
seq_group_ids, seq_groups, sample_results):
|
||||||
|
seq_ids, sampling_params = seq_group
|
||||||
|
next_token_ids, parent_ids = sample_result
|
||||||
|
num_results = len(next_token_ids)
|
||||||
|
num_parent_seqs = len(seq_ids)
|
||||||
|
parent_logprobs = category_logprobs[sample_idx:sample_idx +
|
||||||
|
num_parent_seqs]
|
||||||
|
selected_token_logprobs = batched_logprobs_query_result[
|
||||||
|
result_idx:result_idx + num_results]
|
||||||
|
seq_output = _build_sequence_outputs(parent_ids, next_token_ids,
|
||||||
|
selected_token_logprobs,
|
||||||
|
seq_ids, parent_logprobs,
|
||||||
|
sampling_params.logprobs)
|
||||||
|
seq_outputs_dict[seq_group_id] = seq_output
|
||||||
|
sample_idx += num_parent_seqs
|
||||||
|
result_idx += num_results
|
||||||
|
assert sample_idx == num_tokens
|
||||||
|
category_start_idx += num_tokens
|
||||||
|
|
||||||
# Build the output.
|
return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
|
||||||
for seq_id, parent_seq_id, next_token_id in zip(
|
|
||||||
seq_ids, parent_seq_ids, next_token_ids):
|
|
||||||
j = seq_ids.index(parent_seq_id)
|
|
||||||
output_logprobs = next_logprobs[parent_seq_id].copy()
|
|
||||||
output_logprobs[next_token_id] = logprob[j,
|
|
||||||
next_token_id].item()
|
|
||||||
seq_outputs[seq_id] = SequenceOutputs(
|
|
||||||
seq_id,
|
|
||||||
parent_seq_id,
|
|
||||||
next_token_id,
|
|
||||||
output_logprobs,
|
|
||||||
)
|
|
||||||
|
|
||||||
return seq_outputs
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Utilities for selecting and loading models."""
|
"""Utilities for selecting and loading models."""
|
||||||
|
import contextlib
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -7,20 +8,44 @@ from transformers import PretrainedConfig
|
|||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.model_executor.models import * # pylint: disable=wildcard-import
|
from vllm.model_executor.models import * # pylint: disable=wildcard-import
|
||||||
from vllm.model_executor.weight_utils import initialize_dummy_weights
|
from vllm.model_executor.weight_utils import (get_quant_config,
|
||||||
|
initialize_dummy_weights)
|
||||||
|
|
||||||
# TODO(woosuk): Lazy-load the model classes.
|
# TODO(woosuk): Lazy-load the model classes.
|
||||||
_MODEL_REGISTRY = {
|
_MODEL_REGISTRY = {
|
||||||
|
"AquilaModel": AquilaForCausalLM,
|
||||||
|
"BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b
|
||||||
|
"BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b
|
||||||
"BloomForCausalLM": BloomForCausalLM,
|
"BloomForCausalLM": BloomForCausalLM,
|
||||||
|
"FalconForCausalLM": FalconForCausalLM,
|
||||||
"GPT2LMHeadModel": GPT2LMHeadModel,
|
"GPT2LMHeadModel": GPT2LMHeadModel,
|
||||||
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
|
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
|
||||||
|
"GPTJForCausalLM": GPTJForCausalLM,
|
||||||
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
|
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
|
||||||
|
"InternLMForCausalLM": InternLMForCausalLM,
|
||||||
"LlamaForCausalLM": LlamaForCausalLM,
|
"LlamaForCausalLM": LlamaForCausalLM,
|
||||||
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
|
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
|
||||||
|
"MistralForCausalLM": MistralForCausalLM,
|
||||||
"MPTForCausalLM": MPTForCausalLM,
|
"MPTForCausalLM": MPTForCausalLM,
|
||||||
"OPTForCausalLM": OPTForCausalLM,
|
"OPTForCausalLM": OPTForCausalLM,
|
||||||
|
"QWenLMHeadModel": QWenLMHeadModel,
|
||||||
|
"RWForCausalLM": FalconForCausalLM,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# FIXME(woosuk): Remove this once all models support quantization.
|
||||||
|
_MODEL_CLASSES_SUPPORT_QUANTIZATION = [
|
||||||
|
LlamaForCausalLM,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _set_default_torch_dtype(dtype: torch.dtype):
|
||||||
|
"""Sets the default torch dtype to the given dtype."""
|
||||||
|
old_dtype = torch.get_default_dtype()
|
||||||
|
torch.set_default_dtype(dtype)
|
||||||
|
yield
|
||||||
|
torch.set_default_dtype(old_dtype)
|
||||||
|
|
||||||
|
|
||||||
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||||
architectures = getattr(config, "architectures", [])
|
architectures = getattr(config, "architectures", [])
|
||||||
@@ -34,19 +59,46 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
|||||||
|
|
||||||
def get_model(model_config: ModelConfig) -> nn.Module:
|
def get_model(model_config: ModelConfig) -> nn.Module:
|
||||||
model_class = _get_model_architecture(model_config.hf_config)
|
model_class = _get_model_architecture(model_config.hf_config)
|
||||||
torch.set_default_dtype(model_config.dtype)
|
|
||||||
|
|
||||||
# Create a model instance.
|
# Get the quantization config.
|
||||||
# The weights will be initialized as empty tensors.
|
quant_config = None
|
||||||
model = model_class(model_config.hf_config)
|
if model_config.quantization is not None:
|
||||||
if model_config.use_dummy_weights:
|
if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
|
||||||
model = model.cuda()
|
raise ValueError(
|
||||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
f"Quantization is not supported for {model_class}.")
|
||||||
# random values to the weights.
|
quant_config = get_quant_config(model_config.quantization,
|
||||||
initialize_dummy_weights(model)
|
model_config.model,
|
||||||
else:
|
model_config.download_dir)
|
||||||
# Load the weights from the cached or downloaded files.
|
capability = torch.cuda.get_device_capability()
|
||||||
model.load_weights(model_config.model, model_config.download_dir,
|
capability = capability[0] * 10 + capability[1]
|
||||||
model_config.use_np_weights)
|
if capability < quant_config.get_min_capability():
|
||||||
model = model.cuda()
|
raise ValueError(
|
||||||
|
f"The quantization method {model_config.quantization} is not "
|
||||||
|
"supported for the current GPU. "
|
||||||
|
f"Minimum capability: {quant_config.get_min_capability()}. "
|
||||||
|
f"Current capability: {capability}.")
|
||||||
|
supported_dtypes = quant_config.get_supported_act_dtypes()
|
||||||
|
if model_config.dtype not in supported_dtypes:
|
||||||
|
raise ValueError(
|
||||||
|
f"{model_config.dtype} is not supported for quantization "
|
||||||
|
f"method {model_config.quantization}. Supported dtypes: "
|
||||||
|
f"{supported_dtypes}")
|
||||||
|
|
||||||
|
with _set_default_torch_dtype(model_config.dtype):
|
||||||
|
# Create a model instance.
|
||||||
|
# The weights will be initialized as empty tensors.
|
||||||
|
if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
|
||||||
|
model = model_class(model_config.hf_config, quant_config)
|
||||||
|
else:
|
||||||
|
model = model_class(model_config.hf_config)
|
||||||
|
if model_config.load_format == "dummy":
|
||||||
|
model = model.cuda()
|
||||||
|
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||||
|
# random values to the weights.
|
||||||
|
initialize_dummy_weights(model)
|
||||||
|
else:
|
||||||
|
# Load the weights from the cached or downloaded files.
|
||||||
|
model.load_weights(model_config.model, model_config.download_dir,
|
||||||
|
model_config.load_format, model_config.revision)
|
||||||
|
model = model.cuda()
|
||||||
return model.eval()
|
return model.eval()
|
||||||
|
|||||||
@@ -1,17 +1,33 @@
|
|||||||
|
from vllm.model_executor.models.aquila import AquilaForCausalLM
|
||||||
|
from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM,
|
||||||
|
BaichuanForCausalLM)
|
||||||
from vllm.model_executor.models.bloom import BloomForCausalLM
|
from vllm.model_executor.models.bloom import BloomForCausalLM
|
||||||
|
from vllm.model_executor.models.falcon import FalconForCausalLM
|
||||||
from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
|
from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
|
||||||
from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
|
from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
|
||||||
|
from vllm.model_executor.models.gpt_j import GPTJForCausalLM
|
||||||
from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
|
from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
|
||||||
|
from vllm.model_executor.models.internlm import InternLMForCausalLM
|
||||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||||
from vllm.model_executor.models.mpt import MPTForCausalLM
|
from vllm.model_executor.models.mpt import MPTForCausalLM
|
||||||
from vllm.model_executor.models.opt import OPTForCausalLM
|
from vllm.model_executor.models.opt import OPTForCausalLM
|
||||||
|
from vllm.model_executor.models.qwen import QWenLMHeadModel
|
||||||
|
from vllm.model_executor.models.mistral import MistralForCausalLM
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"AquilaForCausalLM",
|
||||||
|
"BaiChuanForCausalLM",
|
||||||
|
"BaichuanForCausalLM",
|
||||||
"BloomForCausalLM",
|
"BloomForCausalLM",
|
||||||
|
"FalconForCausalLM",
|
||||||
"GPT2LMHeadModel",
|
"GPT2LMHeadModel",
|
||||||
"GPTBigCodeForCausalLM",
|
"GPTBigCodeForCausalLM",
|
||||||
|
"GPTJForCausalLM",
|
||||||
"GPTNeoXForCausalLM",
|
"GPTNeoXForCausalLM",
|
||||||
|
"InternLMForCausalLM",
|
||||||
"LlamaForCausalLM",
|
"LlamaForCausalLM",
|
||||||
"MPTForCausalLM",
|
"MPTForCausalLM",
|
||||||
"OPTForCausalLM",
|
"OPTForCausalLM",
|
||||||
|
"QWenLMHeadModel",
|
||||||
|
"MistralForCausalLM",
|
||||||
]
|
]
|
||||||
|
|||||||
369
vllm/model_executor/models/aquila.py
Normal file
369
vllm/model_executor/models/aquila.py
Normal file
@@ -0,0 +1,369 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Inference-only LLaMA model compatible with HuggingFace weights.
|
||||||
|
|
||||||
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
|
"""
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.weight_utils import (
|
||||||
|
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
|
||||||
|
load_tensor_parallel_weights)
|
||||||
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
|
from vllm.transformers_utils.configs.aquila import AquilaConfig
|
||||||
|
|
||||||
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class AquilaMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
hidden_act: str,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.gate_up_proj = ColumnParallelLinear(hidden_size,
|
||||||
|
2 * intermediate_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.down_proj = RowParallelLinear(intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False)
|
||||||
|
if hidden_act != "silu":
|
||||||
|
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||||
|
"Only silu is supported for now.")
|
||||||
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
|
x = self.act_fn(gate_up)
|
||||||
|
x, _ = self.down_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AquilaRMSNorm(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
|
"""
|
||||||
|
AquilaRMSNorm is equivalent to T5LayerNorm
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
input_dtype = hidden_states.dtype
|
||||||
|
variance = hidden_states.to(torch.float32).pow(2).mean(-1,
|
||||||
|
keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance +
|
||||||
|
self.variance_epsilon)
|
||||||
|
|
||||||
|
return (self.weight * hidden_states).to(input_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class AquilaAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
max_position_embeddings: int = 8192,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.total_num_kv_heads = num_kv_heads
|
||||||
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
|
self.num_kv_heads = self.total_num_kv_heads // tp_size
|
||||||
|
self.head_dim = hidden_size // self.total_num_heads
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
|
self.qkv_proj = ColumnParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||||
|
self.head_dim,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False,
|
||||||
|
)
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
self.total_num_heads * self.head_dim,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False,
|
||||||
|
)
|
||||||
|
self.attn = PagedAttentionWithRoPE(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
base=self.rope_theta,
|
||||||
|
max_position=self.max_position_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
k_cache, v_cache = kv_cache
|
||||||
|
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||||
|
input_metadata, cache_event)
|
||||||
|
output, _ = self.o_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class AquilaDecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: AquilaConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
|
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||||
|
8192)
|
||||||
|
self.self_attn = AquilaAttention(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
num_kv_heads=config.num_attention_heads,
|
||||||
|
rope_theta=rope_theta,
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
)
|
||||||
|
self.mlp = AquilaMLP(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
hidden_act=config.hidden_act,
|
||||||
|
)
|
||||||
|
self.input_layernorm = AquilaRMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = AquilaRMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Self Attention
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
input_metadata=input_metadata,
|
||||||
|
cache_event=cache_event,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class AquilaModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: AquilaConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.padding_idx = config.pad_token_id
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
#vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers)
|
||||||
|
])
|
||||||
|
self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
for i in range(len(self.layers)):
|
||||||
|
if cache_events is None:
|
||||||
|
cache_event = None
|
||||||
|
else:
|
||||||
|
cache_event = cache_events[i]
|
||||||
|
layer = self.layers[i]
|
||||||
|
hidden_states = layer(
|
||||||
|
positions,
|
||||||
|
hidden_states,
|
||||||
|
kv_caches[i],
|
||||||
|
input_metadata,
|
||||||
|
cache_event,
|
||||||
|
)
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class AquilaForCausalLM(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.model = AquilaModel(config)
|
||||||
|
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||||
|
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
||||||
|
vocab_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.sampler = Sampler(config.vocab_size)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> SamplerOutput:
|
||||||
|
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||||
|
input_metadata, cache_events)
|
||||||
|
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||||
|
input_metadata)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
_column_parallel_weights = [
|
||||||
|
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
|
||||||
|
]
|
||||||
|
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||||
|
|
||||||
|
def load_weights(self,
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
|
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||||
|
kv_proj_shard_size = (self.config.hidden_size //
|
||||||
|
self.config.num_attention_heads *
|
||||||
|
self.config.num_attention_heads // tp_size)
|
||||||
|
attention_weight_specs = [
|
||||||
|
# (weight_name, shard_size, offset)
|
||||||
|
("q_proj", q_proj_shard_size, 0),
|
||||||
|
("k_proj", kv_proj_shard_size, q_proj_shard_size),
|
||||||
|
("v_proj", kv_proj_shard_size,
|
||||||
|
q_proj_shard_size + kv_proj_shard_size),
|
||||||
|
]
|
||||||
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_attention_weight = False
|
||||||
|
for weight_name, shard_size, offset in attention_weight_specs:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
param = state_dict[name.replace(weight_name, "qkv_proj")]
|
||||||
|
|
||||||
|
loaded_weight = loaded_weight[
|
||||||
|
shard_size * tensor_model_parallel_rank:shard_size *
|
||||||
|
(tensor_model_parallel_rank + 1)]
|
||||||
|
param_slice = param.data[offset:offset + shard_size]
|
||||||
|
assert param_slice.shape == loaded_weight.shape
|
||||||
|
|
||||||
|
param_slice.copy_(loaded_weight)
|
||||||
|
is_attention_weight = True
|
||||||
|
break
|
||||||
|
if is_attention_weight:
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_gate_up_weight = False
|
||||||
|
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||||
|
shard_size = param.shape[0] // 2
|
||||||
|
loaded_weight = loaded_weight[
|
||||||
|
shard_size * tensor_model_parallel_rank:shard_size *
|
||||||
|
(tensor_model_parallel_rank + 1)]
|
||||||
|
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||||
|
(stride_id + 1)]
|
||||||
|
assert param_slice.shape == loaded_weight.shape
|
||||||
|
param_slice.copy_(loaded_weight)
|
||||||
|
is_gate_up_weight = True
|
||||||
|
break
|
||||||
|
if is_gate_up_weight:
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = state_dict[name]
|
||||||
|
if "embed_tokens" in name or "lm_head" in name:
|
||||||
|
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||||
|
tensor_model_parallel_rank)
|
||||||
|
continue
|
||||||
|
|
||||||
|
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||||
|
self._column_parallel_weights,
|
||||||
|
self._row_parallel_weights,
|
||||||
|
tensor_model_parallel_rank)
|
||||||
387
vllm/model_executor/models/baichuan.py
Normal file
387
vllm/model_executor/models/baichuan.py
Normal file
@@ -0,0 +1,387 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Inference-only BaiChuan model compatible with HuggingFace weights.
|
||||||
|
|
||||||
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
|
||||||
|
PagedAttentionWithALiBi)
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.weight_utils import (
|
||||||
|
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||||
|
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||||
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
|
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
|
||||||
|
|
||||||
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||||
|
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
|
||||||
|
base = torch.tensor(
|
||||||
|
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
||||||
|
slopes = torch.pow(base, powers)
|
||||||
|
|
||||||
|
if closest_power_of_2 != total_num_heads:
|
||||||
|
extra_base = torch.tensor(
|
||||||
|
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
num_remaining_heads = min(closest_power_of_2,
|
||||||
|
total_num_heads - closest_power_of_2)
|
||||||
|
extra_powers = torch.arange(start=1,
|
||||||
|
end=1 + 2 * num_remaining_heads,
|
||||||
|
step=2,
|
||||||
|
dtype=torch.int32)
|
||||||
|
slopes = torch.cat(
|
||||||
|
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||||
|
return slopes
|
||||||
|
|
||||||
|
|
||||||
|
class BaiChuanMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
hidden_act: str,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.gate_up_proj = ColumnParallelLinear(hidden_size,
|
||||||
|
2 * intermediate_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.down_proj = RowParallelLinear(intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False)
|
||||||
|
if hidden_act != "silu":
|
||||||
|
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||||
|
"Only silu is supported for now.")
|
||||||
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
|
x = self.act_fn(gate_up)
|
||||||
|
x, _ = self.down_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class BaiChuanAttention(nn.Module):
|
||||||
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
position_embedding: str,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
max_position_embeddings: int = 8192,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
|
||||||
|
)
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
||||||
|
self.num_heads = (self.total_num_heads //
|
||||||
|
tensor_model_parallel_world_size)
|
||||||
|
self.head_dim = hidden_size // self.total_num_heads
|
||||||
|
self.postion_embedding = position_embedding
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
self.W_pack = ColumnParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
3 * hidden_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False,
|
||||||
|
)
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
self.total_num_heads * self.head_dim,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False,
|
||||||
|
)
|
||||||
|
# Create the alibi slopes and slice them.
|
||||||
|
if self.postion_embedding == "ALIBI":
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
head_start = tp_rank * self.num_heads
|
||||||
|
head_end = (tp_rank + 1) * self.num_heads
|
||||||
|
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
|
||||||
|
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
|
||||||
|
|
||||||
|
scaling = self.head_dim**-0.5
|
||||||
|
self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
|
||||||
|
scaling, alibi_slopes)
|
||||||
|
else:
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.attn = PagedAttentionWithRoPE(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
base=self.rope_theta,
|
||||||
|
max_position=self.max_position_embeddings)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.W_pack(hidden_states)
|
||||||
|
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||||
|
k_cache, v_cache = kv_cache
|
||||||
|
if self.postion_embedding == "ALIBI":
|
||||||
|
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||||
|
cache_event)
|
||||||
|
else:
|
||||||
|
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||||
|
input_metadata, cache_event)
|
||||||
|
|
||||||
|
output, _ = self.o_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class BaiChuanDecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: BaiChuanConfig, position_embedding: str):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
|
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||||
|
8192)
|
||||||
|
self.self_attn = BaiChuanAttention(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
position_embedding=position_embedding,
|
||||||
|
rope_theta=rope_theta,
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
)
|
||||||
|
self.mlp = BaiChuanMLP(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
hidden_act=config.hidden_act,
|
||||||
|
)
|
||||||
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Self Attention
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
input_metadata=input_metadata,
|
||||||
|
cache_event=cache_event,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BaiChuanModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: BaiChuanConfig, position_embedding: str):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.padding_idx = config.pad_token_id
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
BaiChuanDecoderLayer(config, position_embedding)
|
||||||
|
for _ in range(config.num_hidden_layers)
|
||||||
|
])
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
for i in range(len(self.layers)):
|
||||||
|
if cache_events is None:
|
||||||
|
cache_event = None
|
||||||
|
else:
|
||||||
|
cache_event = cache_events[i]
|
||||||
|
layer = self.layers[i]
|
||||||
|
hidden_states = layer(
|
||||||
|
positions,
|
||||||
|
hidden_states,
|
||||||
|
kv_caches[i],
|
||||||
|
input_metadata,
|
||||||
|
cache_event,
|
||||||
|
)
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BaiChuanBaseForCausalLM(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config, position_embedding: str):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.model = BaiChuanModel(config, position_embedding)
|
||||||
|
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
||||||
|
config.vocab_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.sampler = Sampler(config.vocab_size)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> SamplerOutput:
|
||||||
|
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||||
|
input_metadata, cache_events)
|
||||||
|
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||||
|
input_metadata)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
_column_parallel_weights = []
|
||||||
|
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||||
|
|
||||||
|
def load_weights(self,
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
|
tp_world_size = get_tensor_model_parallel_world_size()
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
|
|
||||||
|
if "W_pack" in name:
|
||||||
|
total_num_heads = self.config.num_attention_heads
|
||||||
|
hidden_size = self.config.hidden_size
|
||||||
|
head_size = hidden_size // total_num_heads
|
||||||
|
num_heads = total_num_heads // tp_world_size
|
||||||
|
head_start = tp_rank * num_heads
|
||||||
|
head_end = (tp_rank + 1) * num_heads
|
||||||
|
|
||||||
|
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||||
|
head_size, hidden_size)
|
||||||
|
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
|
||||||
|
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||||
|
|
||||||
|
is_gate_up_weight = False
|
||||||
|
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||||
|
shard_size = param.shape[0] // 2
|
||||||
|
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
|
||||||
|
(tp_rank + 1)]
|
||||||
|
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||||
|
(stride_id + 1)]
|
||||||
|
assert param_slice.shape == loaded_weight.shape
|
||||||
|
param_slice.copy_(loaded_weight)
|
||||||
|
is_gate_up_weight = True
|
||||||
|
break
|
||||||
|
if is_gate_up_weight:
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = state_dict[name]
|
||||||
|
|
||||||
|
if "embed_tokens" in name or "lm_head" in name:
|
||||||
|
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||||
|
tp_rank)
|
||||||
|
continue
|
||||||
|
|
||||||
|
load_tensor_parallel_weights(
|
||||||
|
param,
|
||||||
|
loaded_weight,
|
||||||
|
name,
|
||||||
|
self._column_parallel_weights,
|
||||||
|
self._row_parallel_weights,
|
||||||
|
tp_rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config, "ALIBI")
|
||||||
|
|
||||||
|
|
||||||
|
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config, "ROPE")
|
||||||
@@ -21,7 +21,7 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses
|
|||||||
InputMetadata to extract the original 2D shape of the input.
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
"""
|
"""
|
||||||
import math
|
import math
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -37,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
|||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
@@ -264,7 +264,7 @@ class BloomForCausalLM(nn.Module):
|
|||||||
kv_caches: List[KVCache],
|
kv_caches: List[KVCache],
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_events: Optional[List[torch.cuda.Event]],
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||||
input_metadata, cache_events)
|
input_metadata, cache_events)
|
||||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||||
@@ -279,15 +279,23 @@ class BloomForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
use_np_cache: bool = False):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, use_np_cache):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if not name.startswith("transformer."):
|
if name == "lm_head.weight":
|
||||||
name = "transformer." + name
|
# Since hidden_states are parallelized, we need to
|
||||||
|
# load lm_head.weight in parallel.
|
||||||
|
self._column_parallel_weights.append(name)
|
||||||
|
# If lm_head is provided, use it instead.
|
||||||
|
param = self.lm_head_weight
|
||||||
|
else:
|
||||||
|
if not name.startswith("transformer."):
|
||||||
|
name = "transformer." + name
|
||||||
|
param = state_dict[name]
|
||||||
|
|
||||||
param = state_dict[name]
|
|
||||||
if "query_key_value" in name:
|
if "query_key_value" in name:
|
||||||
# NOTE(woosuk): BLOOM's fused QKV has the shape of
|
# NOTE(woosuk): BLOOM's fused QKV has the shape of
|
||||||
# [num_heads * 3 * head_size, hidden_size], while the
|
# [num_heads * 3 * head_size, hidden_size], while the
|
||||||
|
|||||||
504
vllm/model_executor/models/falcon.py
Normal file
504
vllm/model_executor/models/falcon.py
Normal file
@@ -0,0 +1,504 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights
|
||||||
|
# reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""PyTorch Falcon model."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import LayerNorm
|
||||||
|
from transformers import FalconConfig as HF_FalconConfig
|
||||||
|
|
||||||
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
|
from vllm.model_executor.layers.attention import (PagedAttention,
|
||||||
|
PagedAttentionWithALiBi,
|
||||||
|
PagedAttentionWithRoPE)
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
|
||||||
|
hf_model_weights_iterator,
|
||||||
|
load_tensor_parallel_weights)
|
||||||
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
|
||||||
|
reduce_from_tensor_model_parallel_region)
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
|
from vllm.transformers_utils.configs import RWConfig
|
||||||
|
|
||||||
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
FalconConfig = Union[HF_FalconConfig, RWConfig]
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during
|
||||||
|
# training, this means that there's one additional quantization to bfloat16
|
||||||
|
# between the operations. In order not to degrade the quality of our HF-port,
|
||||||
|
# we keep these characteristics in the final model.
|
||||||
|
class FalconLinear(nn.Linear):
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
hidden_states = x @ self.weight.T
|
||||||
|
if self.bias is None:
|
||||||
|
return hidden_states
|
||||||
|
return hidden_states + self.bias
|
||||||
|
|
||||||
|
|
||||||
|
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||||
|
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
|
||||||
|
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
||||||
|
dtype=torch.float32)
|
||||||
|
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
||||||
|
slopes = torch.pow(base, powers)
|
||||||
|
|
||||||
|
if closest_power_of_2 != total_num_heads:
|
||||||
|
extra_base = torch.tensor(
|
||||||
|
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
|
||||||
|
dtype=torch.float32)
|
||||||
|
num_remaining_heads = min(closest_power_of_2,
|
||||||
|
total_num_heads - closest_power_of_2)
|
||||||
|
extra_powers = torch.arange(1,
|
||||||
|
1 + 2 * num_remaining_heads,
|
||||||
|
2,
|
||||||
|
dtype=torch.int32)
|
||||||
|
slopes = torch.cat(
|
||||||
|
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||||
|
|
||||||
|
return slopes
|
||||||
|
|
||||||
|
|
||||||
|
class FalconAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: FalconConfig):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
|
self.total_num_heads = config.num_attention_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.head_dim = self.hidden_size // self.total_num_heads
|
||||||
|
assert self.head_dim * self.total_num_heads == self.hidden_size
|
||||||
|
|
||||||
|
self.new_decoder_architecture = config.new_decoder_architecture
|
||||||
|
self.multi_query = config.multi_query
|
||||||
|
|
||||||
|
if self.new_decoder_architecture:
|
||||||
|
self.total_num_kv_heads = config.num_kv_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
self.num_kv_heads = self.total_num_kv_heads // tp_size
|
||||||
|
self.query_key_value = ColumnParallelLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||||
|
self.head_dim,
|
||||||
|
bias=config.bias,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False,
|
||||||
|
skip_bias_add=True,
|
||||||
|
)
|
||||||
|
elif self.multi_query:
|
||||||
|
self.total_num_kv_heads = 1
|
||||||
|
self.num_kv_heads = 1
|
||||||
|
self.query = ColumnParallelLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.total_num_heads * self.head_dim,
|
||||||
|
bias=config.bias,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False,
|
||||||
|
skip_bias_add=True,
|
||||||
|
)
|
||||||
|
self.key_value = FalconLinear(self.hidden_size,
|
||||||
|
2 * self.head_dim,
|
||||||
|
bias=config.bias)
|
||||||
|
else:
|
||||||
|
self.total_num_kv_heads = self.total_num_heads
|
||||||
|
self.num_kv_heads = self.num_heads
|
||||||
|
self.query_key_value = ColumnParallelLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||||
|
self.head_dim,
|
||||||
|
bias=config.bias,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False,
|
||||||
|
skip_bias_add=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
|
||||||
|
# Layer-wise attention scaling
|
||||||
|
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
|
||||||
|
self.reduce_row_parallel_results = not (config.new_decoder_architecture
|
||||||
|
or config.parallel_attn)
|
||||||
|
self.dense = RowParallelLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.hidden_size,
|
||||||
|
bias=config.bias,
|
||||||
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False,
|
||||||
|
skip_bias_add=True,
|
||||||
|
reduce_results=self.reduce_row_parallel_results)
|
||||||
|
|
||||||
|
self.use_rotary = config.rotary
|
||||||
|
self.use_alibi = config.alibi
|
||||||
|
assert not (self.use_rotary and self.use_alibi), (
|
||||||
|
"Rotary and alibi are mutually exclusive.")
|
||||||
|
|
||||||
|
if self.use_rotary:
|
||||||
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
|
max_position_embeddings = getattr(config,
|
||||||
|
"max_position_embeddings", 8192)
|
||||||
|
self.attn = PagedAttentionWithRoPE(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.inv_norm_factor,
|
||||||
|
base=rope_theta,
|
||||||
|
max_position=max_position_embeddings,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
num_kv_heads=self.num_kv_heads)
|
||||||
|
elif self.use_alibi:
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
head_start = tp_rank * self.num_heads
|
||||||
|
head_end = (tp_rank + 1) * self.num_heads
|
||||||
|
alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
|
||||||
|
self.inv_norm_factor)
|
||||||
|
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
|
||||||
|
self.attn = PagedAttentionWithALiBi(self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.inv_norm_factor,
|
||||||
|
alibi_slopes,
|
||||||
|
num_kv_heads=self.num_kv_heads)
|
||||||
|
else:
|
||||||
|
self.attn = PagedAttention(self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
scale=self.inv_norm_factor,
|
||||||
|
num_kv_heads=self.num_kv_heads)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if not self.new_decoder_architecture and self.multi_query:
|
||||||
|
q, bias = self.query(hidden_states)
|
||||||
|
if bias is not None:
|
||||||
|
q += bias
|
||||||
|
kv = self.key_value(hidden_states)
|
||||||
|
k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
|
||||||
|
else:
|
||||||
|
qkv, bias = self.query_key_value(hidden_states)
|
||||||
|
if bias is not None:
|
||||||
|
qkv += bias
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
|
||||||
|
dim=-1)
|
||||||
|
k_cache, v_cache = kv_cache
|
||||||
|
if self.use_rotary:
|
||||||
|
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||||
|
input_metadata, cache_event)
|
||||||
|
else:
|
||||||
|
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||||
|
cache_event)
|
||||||
|
attn_output, bias = self.dense(attn_output)
|
||||||
|
return attn_output, bias
|
||||||
|
|
||||||
|
|
||||||
|
class FalconMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: FalconConfig):
|
||||||
|
super().__init__()
|
||||||
|
hidden_size = config.hidden_size
|
||||||
|
|
||||||
|
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
|
||||||
|
4 * hidden_size,
|
||||||
|
bias=config.bias,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False,
|
||||||
|
skip_bias_add=True)
|
||||||
|
self.act = nn.GELU()
|
||||||
|
self.reduce_row_parallel_results = not (config.new_decoder_architecture
|
||||||
|
or config.parallel_attn)
|
||||||
|
self.dense_4h_to_h = RowParallelLinear(
|
||||||
|
4 * hidden_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=config.bias,
|
||||||
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False,
|
||||||
|
skip_bias_add=True,
|
||||||
|
reduce_results=self.reduce_row_parallel_results)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
|
||||||
|
x, bias = self.dense_h_to_4h(x)
|
||||||
|
if bias is not None:
|
||||||
|
x += bias
|
||||||
|
x = self.act(x)
|
||||||
|
x, bias = self.dense_4h_to_h(x)
|
||||||
|
return x, bias
|
||||||
|
|
||||||
|
|
||||||
|
class FalconDecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: FalconConfig):
|
||||||
|
super().__init__()
|
||||||
|
hidden_size = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.self_attention = FalconAttention(config)
|
||||||
|
self.mlp = FalconMLP(config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
if config.new_decoder_architecture:
|
||||||
|
# The layer norm before self-attention
|
||||||
|
self.ln_attn = LayerNorm(hidden_size,
|
||||||
|
eps=config.layer_norm_epsilon)
|
||||||
|
# The layer norm before the MLP
|
||||||
|
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
else:
|
||||||
|
self.input_layernorm = LayerNorm(hidden_size,
|
||||||
|
eps=config.layer_norm_epsilon)
|
||||||
|
if not config.parallel_attn:
|
||||||
|
self.post_attention_layernorm = LayerNorm(
|
||||||
|
hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
|
self.reduce_row_parallel_results = not (config.new_decoder_architecture
|
||||||
|
or config.parallel_attn)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
):
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
if self.config.new_decoder_architecture:
|
||||||
|
attention_layernorm_out = self.ln_attn(hidden_states)
|
||||||
|
mlp_layernorm_out = self.ln_mlp(hidden_states)
|
||||||
|
else:
|
||||||
|
attention_layernorm_out = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Self attention.
|
||||||
|
attention_output, attention_bias = self.self_attention(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=attention_layernorm_out,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
input_metadata=input_metadata,
|
||||||
|
cache_event=cache_event,
|
||||||
|
)
|
||||||
|
if self.reduce_row_parallel_results and attention_bias is not None:
|
||||||
|
attention_output += attention_bias
|
||||||
|
|
||||||
|
if not self.config.new_decoder_architecture:
|
||||||
|
if self.config.parallel_attn:
|
||||||
|
mlp_layernorm_out = attention_layernorm_out
|
||||||
|
else:
|
||||||
|
residual += attention_output
|
||||||
|
mlp_layernorm_out = self.post_attention_layernorm(residual)
|
||||||
|
|
||||||
|
# MLP.
|
||||||
|
mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
|
||||||
|
if self.reduce_row_parallel_results and mlp_bias is not None:
|
||||||
|
mlp_output += mlp_bias
|
||||||
|
|
||||||
|
if not self.reduce_row_parallel_results:
|
||||||
|
# When MLP and Attention layers are parallel, we can use
|
||||||
|
# only one all-reduce operator to reduce the results from
|
||||||
|
# both MLP and Attention layers.
|
||||||
|
mlp_output += attention_output
|
||||||
|
mlp_output = reduce_from_tensor_model_parallel_region(mlp_output)
|
||||||
|
if attention_bias is not None:
|
||||||
|
mlp_output += attention_bias
|
||||||
|
if mlp_bias is not None:
|
||||||
|
mlp_output += mlp_bias
|
||||||
|
|
||||||
|
output = mlp_output + residual
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class FalconModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: FalconConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.use_alibi = config.alibi
|
||||||
|
|
||||||
|
# Embedding + LN Embedding
|
||||||
|
self.word_embeddings = VocabParallelEmbedding(
|
||||||
|
config.vocab_size, self.embed_dim, perform_initialization=False)
|
||||||
|
|
||||||
|
# Transformer blocks
|
||||||
|
self.h = nn.ModuleList([
|
||||||
|
FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
# Final Layer Norm
|
||||||
|
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.word_embeddings(input_ids)
|
||||||
|
for i in range(len(self.h)):
|
||||||
|
if cache_events is None:
|
||||||
|
cache_event = None
|
||||||
|
else:
|
||||||
|
cache_event = cache_events[i]
|
||||||
|
layer = self.h[i]
|
||||||
|
hidden_states = layer(
|
||||||
|
positions,
|
||||||
|
hidden_states,
|
||||||
|
kv_caches[i],
|
||||||
|
input_metadata,
|
||||||
|
cache_event,
|
||||||
|
)
|
||||||
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FalconForCausalLM(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: FalconConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.transformer = FalconModel(config)
|
||||||
|
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
||||||
|
config.vocab_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.sampler = Sampler(config.vocab_size)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> SamplerOutput:
|
||||||
|
hidden_states = self.transformer(
|
||||||
|
input_ids,
|
||||||
|
positions,
|
||||||
|
kv_caches,
|
||||||
|
input_metadata,
|
||||||
|
cache_events,
|
||||||
|
)
|
||||||
|
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||||
|
input_metadata)
|
||||||
|
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
_column_parallel_weights = [
|
||||||
|
"word_embeddings.weight", "lm_head.weight", "dense_h_to_4h.weight",
|
||||||
|
"dense_h_to_4h.bias"
|
||||||
|
]
|
||||||
|
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
|
||||||
|
|
||||||
|
def load_weights(self,
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
|
tp_size = (get_tensor_model_parallel_world_size())
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
|
||||||
|
hidden_size = self.config.hidden_size
|
||||||
|
total_num_heads = self.config.num_attention_heads
|
||||||
|
num_heads = total_num_heads // tp_size
|
||||||
|
head_size = hidden_size // total_num_heads
|
||||||
|
head_start = tp_rank * num_heads
|
||||||
|
head_end = (tp_rank + 1) * num_heads
|
||||||
|
if self.config.new_decoder_architecture:
|
||||||
|
total_num_kv_heads = self.config.num_kv_heads
|
||||||
|
num_kv_heads = total_num_kv_heads // tp_size
|
||||||
|
separated_q_kv = False
|
||||||
|
kv_head_start = tp_rank * num_kv_heads
|
||||||
|
kv_head_end = (tp_rank + 1) * num_kv_heads
|
||||||
|
elif self.config.multi_query:
|
||||||
|
total_num_kv_heads = 1
|
||||||
|
num_kv_heads = 1
|
||||||
|
separated_q_kv = True
|
||||||
|
kv_head_start = 0
|
||||||
|
kv_head_end = 1
|
||||||
|
else:
|
||||||
|
total_num_kv_heads = total_num_heads
|
||||||
|
num_kv_heads = total_num_kv_heads // tp_size
|
||||||
|
separated_q_kv = False
|
||||||
|
kv_head_start = tp_rank * num_kv_heads
|
||||||
|
kv_head_end = (tp_rank + 1) * num_kv_heads
|
||||||
|
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
|
||||||
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
|
if "query_key_value" in name:
|
||||||
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
|
loaded_weight_size = loaded_weight.size()
|
||||||
|
loaded_weight = loaded_weight.view(
|
||||||
|
total_num_kv_heads, num_query_heads_per_kv_head + 2,
|
||||||
|
head_size, *loaded_weight_size[1:])
|
||||||
|
|
||||||
|
wq = loaded_weight[:, :-2].reshape(-1, *loaded_weight_size[1:])
|
||||||
|
wk = loaded_weight[:, [-2]].reshape(-1,
|
||||||
|
*loaded_weight_size[1:])
|
||||||
|
wv = loaded_weight[:, [-1]].reshape(-1,
|
||||||
|
*loaded_weight_size[1:])
|
||||||
|
|
||||||
|
wq = wq[head_size * head_start:head_size * head_end]
|
||||||
|
wk = wk[head_size * kv_head_start:head_size * kv_head_end]
|
||||||
|
wv = wv[head_size * kv_head_start:head_size * kv_head_end]
|
||||||
|
|
||||||
|
if separated_q_kv:
|
||||||
|
loaded_weight_q = wq
|
||||||
|
loaded_weight_kv = torch.cat([wk, wv], dim=0)
|
||||||
|
q_weight_name = name.replace("query_key_value", "query")
|
||||||
|
kv_weight_name = name.replace("query_key_value",
|
||||||
|
"key_value")
|
||||||
|
load_tensor_parallel_weights(state_dict[q_weight_name],
|
||||||
|
loaded_weight_q,
|
||||||
|
q_weight_name,
|
||||||
|
self._column_parallel_weights,
|
||||||
|
self._row_parallel_weights,
|
||||||
|
tp_rank)
|
||||||
|
load_tensor_parallel_weights(state_dict[kv_weight_name],
|
||||||
|
loaded_weight_kv,
|
||||||
|
kv_weight_name,
|
||||||
|
self._column_parallel_weights,
|
||||||
|
self._row_parallel_weights,
|
||||||
|
tp_rank)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
loaded_weight = torch.cat([wq, wk, wv], dim=0)
|
||||||
|
|
||||||
|
param = state_dict[name]
|
||||||
|
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||||
|
self._column_parallel_weights,
|
||||||
|
self._row_parallel_weights, tp_rank)
|
||||||
@@ -21,7 +21,7 @@
|
|||||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
InputMetadata to extract the original 2D shape of the input.
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
"""
|
"""
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -31,13 +31,14 @@ from vllm.model_executor.input_metadata import InputMetadata
|
|||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.attention import PagedAttention
|
from vllm.model_executor.layers.attention import PagedAttention
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
from vllm.model_executor.weight_utils import (
|
||||||
load_tensor_parallel_weights)
|
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||||
|
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
@@ -217,27 +218,28 @@ class GPT2LMHeadModel(nn.Module):
|
|||||||
kv_caches: List[KVCache],
|
kv_caches: List[KVCache],
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_events: Optional[List[torch.cuda.Event]],
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||||
input_metadata, cache_events)
|
input_metadata, cache_events)
|
||||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||||
input_metadata)
|
input_metadata)
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
|
_column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
|
||||||
_row_parallel_weights = ["c_proj.weight"]
|
_row_parallel_weights = ["c_proj.weight"]
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
use_np_cache: bool = False):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
tensor_model_parallel_world_size = (
|
tensor_model_parallel_world_size = (
|
||||||
get_tensor_model_parallel_world_size())
|
get_tensor_model_parallel_world_size())
|
||||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, use_np_cache):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if "lm_head.weight" in name:
|
if "lm_head.weight" in name:
|
||||||
# GPT-2 ties the weights of the embedding layer and the final
|
# GPT-2 ties the weights of the embedding layer and the final
|
||||||
# linear layer.
|
# linear layer.
|
||||||
@@ -250,6 +252,8 @@ class GPT2LMHeadModel(nn.Module):
|
|||||||
if not name.startswith("transformer."):
|
if not name.startswith("transformer."):
|
||||||
name = "transformer." + name
|
name = "transformer." + name
|
||||||
|
|
||||||
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
|
|
||||||
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
|
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
|
||||||
# Because of this, we need to transpose the weights.
|
# Because of this, we need to transpose the weights.
|
||||||
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
|
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
|
||||||
@@ -261,14 +265,9 @@ class GPT2LMHeadModel(nn.Module):
|
|||||||
param = state_dict[name]
|
param = state_dict[name]
|
||||||
|
|
||||||
if name == "transformer.wte.weight":
|
if name == "transformer.wte.weight":
|
||||||
# Consider padding in the vocab size.
|
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||||
padded_vocab_size = (param.shape[0] *
|
tensor_model_parallel_rank)
|
||||||
tensor_model_parallel_world_size)
|
continue
|
||||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
|
||||||
extra_rows = torch.empty(num_extra_rows,
|
|
||||||
loaded_weight.shape[1])
|
|
||||||
extra_rows = extra_rows.to(loaded_weight)
|
|
||||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
|
||||||
|
|
||||||
# For the fused QKV linear layer, manually shard the weights.
|
# For the fused QKV linear layer, manually shard the weights.
|
||||||
if "c_attn" in name:
|
if "c_attn" in name:
|
||||||
|
|||||||
@@ -22,24 +22,24 @@
|
|||||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
InputMetadata to extract the original 2D shape of the input.
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
"""
|
"""
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import numpy as np
|
|
||||||
from transformers import GPTBigCodeConfig
|
from transformers import GPTBigCodeConfig
|
||||||
|
|
||||||
from vllm.model_executor.input_metadata import InputMetadata
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.attention import PagedAttention
|
from vllm.model_executor.layers.attention import PagedAttention
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
from vllm.model_executor.weight_utils import (
|
||||||
load_tensor_parallel_weights)
|
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||||
|
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
@@ -50,18 +50,36 @@ class GPTBigCodeAttention(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
total_num_heads = config.num_attention_heads
|
total_num_heads = config.num_attention_heads
|
||||||
tensor_model_parallel_world_size = (
|
self.tensor_model_parallel_world_size = (
|
||||||
get_tensor_model_parallel_world_size())
|
get_tensor_model_parallel_world_size())
|
||||||
assert total_num_heads % tensor_model_parallel_world_size == 0
|
assert total_num_heads % self.tensor_model_parallel_world_size == 0
|
||||||
self.num_heads = total_num_heads // tensor_model_parallel_world_size
|
self.num_heads = (total_num_heads //
|
||||||
|
self.tensor_model_parallel_world_size)
|
||||||
self.head_dim = self.hidden_size // total_num_heads
|
self.head_dim = self.hidden_size // total_num_heads
|
||||||
self.scale = self.head_dim**-0.5
|
self.scale = self.head_dim**-0.5
|
||||||
|
|
||||||
self.c_attn = ColumnParallelLinear(self.hidden_size,
|
self.multi_query = config.multi_query
|
||||||
3 * self.hidden_size,
|
if self.multi_query:
|
||||||
bias=True,
|
self.num_kv_heads = 1
|
||||||
gather_output=False,
|
self.kv_dim = self.head_dim
|
||||||
perform_initialization=False)
|
self.c_attn_q = ColumnParallelLinear(self.hidden_size,
|
||||||
|
self.hidden_size,
|
||||||
|
bias=True,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.c_attn_kv = nn.Linear(self.hidden_size,
|
||||||
|
2 * self.kv_dim,
|
||||||
|
bias=True)
|
||||||
|
else:
|
||||||
|
self.num_kv_heads = self.num_heads
|
||||||
|
self.kv_dim = self.num_kv_heads * self.head_dim
|
||||||
|
self.c_attn = ColumnParallelLinear(self.hidden_size,
|
||||||
|
self.hidden_size +
|
||||||
|
2 * self.kv_dim,
|
||||||
|
bias=True,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False)
|
||||||
|
|
||||||
self.c_proj = RowParallelLinear(self.hidden_size,
|
self.c_proj = RowParallelLinear(self.hidden_size,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
bias=True,
|
bias=True,
|
||||||
@@ -69,7 +87,8 @@ class GPTBigCodeAttention(nn.Module):
|
|||||||
perform_initialization=False)
|
perform_initialization=False)
|
||||||
self.attn = PagedAttention(self.num_heads,
|
self.attn = PagedAttention(self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
scale=self.scale)
|
scale=self.scale,
|
||||||
|
num_kv_heads=self.num_kv_heads)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -78,8 +97,17 @@ class GPTBigCodeAttention(nn.Module):
|
|||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_event: Optional[torch.cuda.Event],
|
cache_event: Optional[torch.cuda.Event],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
qkv, _ = self.c_attn(hidden_states)
|
if self.multi_query:
|
||||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
q, _ = self.c_attn_q(hidden_states)
|
||||||
|
kv = self.c_attn_kv(hidden_states)
|
||||||
|
k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
|
||||||
|
else:
|
||||||
|
qkv, _ = self.c_attn(hidden_states)
|
||||||
|
q, k, v = qkv.split([
|
||||||
|
self.hidden_size // self.tensor_model_parallel_world_size,
|
||||||
|
self.kv_dim, self.kv_dim
|
||||||
|
],
|
||||||
|
dim=-1)
|
||||||
key_cache, value_cache = kv_cache
|
key_cache, value_cache = kv_cache
|
||||||
attn_output = self.attn(q, k, v, key_cache, value_cache,
|
attn_output = self.attn(q, k, v, key_cache, value_cache,
|
||||||
input_metadata, cache_event)
|
input_metadata, cache_event)
|
||||||
@@ -218,27 +246,28 @@ class GPTBigCodeForCausalLM(nn.Module):
|
|||||||
kv_caches: List[KVCache],
|
kv_caches: List[KVCache],
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_events: Optional[List[torch.cuda.Event]],
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||||
input_metadata, cache_events)
|
input_metadata, cache_events)
|
||||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||||
input_metadata)
|
input_metadata)
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
|
_column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
|
||||||
_row_parallel_weights = ["c_proj.weight"]
|
_row_parallel_weights = ["c_proj.weight"]
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
use_np_cache: bool = False):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
tensor_model_parallel_world_size = (
|
tensor_model_parallel_world_size = (
|
||||||
get_tensor_model_parallel_world_size())
|
get_tensor_model_parallel_world_size())
|
||||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, use_np_cache):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if "lm_head.weight" in name:
|
if "lm_head.weight" in name:
|
||||||
# GPT-2 ties the weights of the embedding layer and the final
|
# GPT-2 ties the weights of the embedding layer and the final
|
||||||
# linear layer.
|
# linear layer.
|
||||||
@@ -248,51 +277,9 @@ class GPTBigCodeForCausalLM(nn.Module):
|
|||||||
# NOTE: "c_attn.bias" should not be skipped.
|
# NOTE: "c_attn.bias" should not be skipped.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
param = state_dict[name]
|
|
||||||
|
|
||||||
if not name.startswith("transformer."):
|
if not name.startswith("transformer."):
|
||||||
name = "transformer." + name
|
name = "transformer." + name
|
||||||
|
|
||||||
if name == "transformer.wte.weight":
|
|
||||||
# Consider padding in the vocab size.
|
|
||||||
padded_vocab_size = param.shape[
|
|
||||||
0] * tensor_model_parallel_world_size
|
|
||||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
|
||||||
extra_rows = torch.empty(num_extra_rows,
|
|
||||||
loaded_weight.shape[1])
|
|
||||||
extra_rows = extra_rows.to(loaded_weight)
|
|
||||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
|
||||||
|
|
||||||
def _expand_mqa_mha(qkv_array, n_head, head_dim):
|
|
||||||
"""manipulates along axis=0 from MQA to MHA
|
|
||||||
inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim)
|
|
||||||
with n_heads for q, then 1 for k, 1 for 1 v, times head dim
|
|
||||||
return: qkv_array.shape=(3 * n_heads * head_dim, hidden_dim)
|
|
||||||
|
|
||||||
TODO: this function is no longer needed once vllm supports MQA.
|
|
||||||
"""
|
|
||||||
qkv_array = qkv_array.numpy()
|
|
||||||
|
|
||||||
dims_q = n_head * head_dim
|
|
||||||
# pylint: disable=unbalanced-tuple-unpacking
|
|
||||||
q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
|
|
||||||
axis=0)
|
|
||||||
# q is fine, but k & v have not replicated shape along the first
|
|
||||||
# axis as long as MQA is not nativly supported, increase memory
|
|
||||||
# and replicated (head_dim, hidden_dim) to
|
|
||||||
# (n_heads * head_dim, hidden_dim)
|
|
||||||
if k.ndim == 2 and v.ndim == 2:
|
|
||||||
replication = (n_head, 1) # weights
|
|
||||||
else:
|
|
||||||
replication = n_head # biases
|
|
||||||
# replicate n_head times for q, v
|
|
||||||
k, v = np.tile(k, replication), np.tile(v, replication)
|
|
||||||
# concat q, k, v along the first axis
|
|
||||||
# (n_heads * head_dim, hidden_dim)
|
|
||||||
# to (3 * n_heads * head_dim, hidden_dim)
|
|
||||||
qkv_array = np.concatenate((q, k, v), axis=0)
|
|
||||||
return torch.from_numpy(qkv_array)
|
|
||||||
|
|
||||||
# For the fused QKV linear layer, manually shard the weights.
|
# For the fused QKV linear layer, manually shard the weights.
|
||||||
if "c_attn" in name:
|
if "c_attn" in name:
|
||||||
# GPT-2's fused QKV has the shape of
|
# GPT-2's fused QKV has the shape of
|
||||||
@@ -300,30 +287,53 @@ class GPTBigCodeForCausalLM(nn.Module):
|
|||||||
# When tensor parallelism is used, we shard the weights along
|
# When tensor parallelism is used, we shard the weights along
|
||||||
# the head dimension.
|
# the head dimension.
|
||||||
total_num_heads = self.config.num_attention_heads
|
total_num_heads = self.config.num_attention_heads
|
||||||
|
total_num_kv_heads = (1 if self.config.multi_query else
|
||||||
|
total_num_heads)
|
||||||
hidden_size = self.config.hidden_size
|
hidden_size = self.config.hidden_size
|
||||||
head_size = hidden_size // total_num_heads
|
head_size = hidden_size // total_num_heads
|
||||||
|
total_kv_size = head_size * total_num_kv_heads
|
||||||
num_heads = total_num_heads // tensor_model_parallel_world_size
|
num_heads = total_num_heads // tensor_model_parallel_world_size
|
||||||
head_start = tensor_model_parallel_rank * num_heads
|
head_start = tensor_model_parallel_rank * num_heads
|
||||||
head_end = (tensor_model_parallel_rank + 1) * num_heads
|
head_end = (tensor_model_parallel_rank + 1) * num_heads
|
||||||
|
|
||||||
if name.endswith(".weight"):
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
loaded_weight = _expand_mqa_mha(loaded_weight,
|
wq, wk, wv = torch.split(
|
||||||
n_head=total_num_heads,
|
loaded_weight, [hidden_size, total_kv_size, total_kv_size],
|
||||||
head_dim=head_size)
|
dim=0)
|
||||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
|
||||||
head_size, hidden_size)
|
wq = wq[head_size * head_start:head_size * head_end]
|
||||||
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
|
if not self.config.multi_query:
|
||||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
# Split the heads when using normal multi-head attention
|
||||||
elif name.endswith(".bias"):
|
wk = wk[head_size * head_start:head_size * head_end]
|
||||||
loaded_weight = _expand_mqa_mha(loaded_weight,
|
wv = wv[head_size * head_start:head_size * head_end]
|
||||||
n_head=total_num_heads,
|
loaded_weight = torch.cat([wq, wk, wv], dim=0)
|
||||||
head_dim=head_size)
|
|
||||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
|
||||||
head_size)
|
|
||||||
loaded_weight = loaded_weight[:, head_start:head_end, :]
|
|
||||||
loaded_weight = loaded_weight.reshape(-1)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected parameter name {name}")
|
# For multi-query attention, we split the query
|
||||||
|
# but replicate the key and value.
|
||||||
|
loaded_weight_q = wq
|
||||||
|
loaded_weight_kv = torch.cat([wk, wv], dim=0)
|
||||||
|
q_weight_name = name.replace("c_attn", "c_attn_q")
|
||||||
|
kv_weight_name = name.replace("c_attn", "c_attn_kv")
|
||||||
|
load_tensor_parallel_weights(state_dict[q_weight_name],
|
||||||
|
loaded_weight_q,
|
||||||
|
q_weight_name,
|
||||||
|
self._column_parallel_weights,
|
||||||
|
self._row_parallel_weights,
|
||||||
|
tensor_model_parallel_rank)
|
||||||
|
load_tensor_parallel_weights(state_dict[kv_weight_name],
|
||||||
|
loaded_weight_kv,
|
||||||
|
kv_weight_name,
|
||||||
|
self._column_parallel_weights,
|
||||||
|
self._row_parallel_weights,
|
||||||
|
tensor_model_parallel_rank)
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = state_dict[name]
|
||||||
|
|
||||||
|
if name == "transformer.wte.weight":
|
||||||
|
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||||
|
tensor_model_parallel_rank)
|
||||||
|
continue
|
||||||
|
|
||||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||||
self._column_parallel_weights,
|
self._column_parallel_weights,
|
||||||
|
|||||||
261
vllm/model_executor/models/gpt_j.py
Normal file
261
vllm/model_executor/models/gpt_j.py
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Inference-only GPT-J model compatible with HuggingFace weights.
|
||||||
|
|
||||||
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
|
"""
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import GPTJConfig
|
||||||
|
|
||||||
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
|
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||||
|
load_tensor_parallel_weights)
|
||||||
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class GPTJAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: GPTJConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.total_num_heads = config.num_attention_heads
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.head_size = self.hidden_size // self.total_num_heads
|
||||||
|
|
||||||
|
self.qkv_proj = ColumnParallelLinear(config.hidden_size,
|
||||||
|
3 * config.hidden_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.out_proj = RowParallelLinear(config.hidden_size,
|
||||||
|
config.hidden_size,
|
||||||
|
bias=False,
|
||||||
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False)
|
||||||
|
|
||||||
|
tp_world_size = get_tensor_model_parallel_world_size()
|
||||||
|
assert self.total_num_heads % tp_world_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_world_size
|
||||||
|
|
||||||
|
scaling = self.head_size**-0.5
|
||||||
|
assert getattr(config, "rotary", True)
|
||||||
|
assert config.rotary_dim % 2 == 0
|
||||||
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
|
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||||
|
8192)
|
||||||
|
self.attn = PagedAttentionWithRoPE(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_size,
|
||||||
|
scaling,
|
||||||
|
config.rotary_dim,
|
||||||
|
base=rope_theta,
|
||||||
|
max_position=max_position_embeddings,
|
||||||
|
is_neox_style=False)
|
||||||
|
self.warmup = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||||
|
k_cache, v_cache = kv_cache
|
||||||
|
attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
|
||||||
|
input_metadata, cache_event)
|
||||||
|
attn_output, _ = self.out_proj(attn_output)
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
class GPTJMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, intermediate_size: int, config: GPTJConfig):
|
||||||
|
super().__init__()
|
||||||
|
hidden_size = config.n_embd
|
||||||
|
self.fc_in = ColumnParallelLinear(hidden_size,
|
||||||
|
intermediate_size,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.fc_out = RowParallelLinear(intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.act = get_act_fn(config.activation_function)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
hidden_states, _ = self.fc_in(hidden_states)
|
||||||
|
hidden_states = self.act(hidden_states)
|
||||||
|
hidden_states, _ = self.fc_out(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class GPTJBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: GPTJConfig):
|
||||||
|
super().__init__()
|
||||||
|
if config.n_inner is None:
|
||||||
|
inner_dim = 4 * config.n_embd
|
||||||
|
else:
|
||||||
|
inner_dim = config.n_inner
|
||||||
|
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||||
|
self.attn = GPTJAttention(config)
|
||||||
|
self.mlp = GPTJMLP(inner_dim, config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.ln_1(hidden_states)
|
||||||
|
attn_output = self.attn(
|
||||||
|
position_ids=position_ids,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
input_metadata=input_metadata,
|
||||||
|
cache_event=cache_event,
|
||||||
|
)
|
||||||
|
mlp_output = self.mlp(hidden_states)
|
||||||
|
hidden_states = attn_output + mlp_output + residual
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class GPTJModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: GPTJConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.embed_dim = config.n_embd
|
||||||
|
self.wte = VocabParallelEmbedding(config.vocab_size,
|
||||||
|
self.embed_dim,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.h = nn.ModuleList(
|
||||||
|
[GPTJBlock(config) for _ in range(config.n_layer)])
|
||||||
|
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.wte(input_ids)
|
||||||
|
for i in range(len(self.h)):
|
||||||
|
if cache_events is None:
|
||||||
|
cache_event = None
|
||||||
|
else:
|
||||||
|
cache_event = cache_events[i]
|
||||||
|
layer = self.h[i]
|
||||||
|
hidden_states = layer(
|
||||||
|
position_ids,
|
||||||
|
hidden_states,
|
||||||
|
kv_caches[i],
|
||||||
|
input_metadata,
|
||||||
|
cache_event,
|
||||||
|
)
|
||||||
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class GPTJForCausalLM(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: GPTJConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
assert not config.tie_word_embeddings
|
||||||
|
self.transformer = GPTJModel(config)
|
||||||
|
self.lm_head = ColumnParallelLinear(config.n_embd,
|
||||||
|
config.vocab_size,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.sampler = Sampler(config.vocab_size)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> SamplerOutput:
|
||||||
|
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||||
|
input_metadata, cache_events)
|
||||||
|
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||||
|
input_metadata, self.lm_head.bias)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
_column_parallel_weights = [
|
||||||
|
"wte.weight", "fc_in.weight", "fc_in.bias", "lm_head.weight",
|
||||||
|
"lm_head.bias"
|
||||||
|
]
|
||||||
|
_row_parallel_weights = ["out_proj.weight", "fc_out.weight"]
|
||||||
|
|
||||||
|
def load_weights(self,
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
state_dict = self.state_dict()
|
||||||
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
|
if "attn.bias" in name or "attn.masked_bias" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_attention_weight = False
|
||||||
|
for stride_id, att_weight_name in enumerate(
|
||||||
|
["q_proj", "k_proj", "v_proj"]):
|
||||||
|
if att_weight_name not in name:
|
||||||
|
continue
|
||||||
|
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
|
||||||
|
shard_size = param.shape[1]
|
||||||
|
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
|
||||||
|
(tp_rank + 1)]
|
||||||
|
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||||
|
(stride_id + 1)]
|
||||||
|
assert param_slice.shape == loaded_weight.shape
|
||||||
|
param_slice.copy_(loaded_weight)
|
||||||
|
is_attention_weight = True
|
||||||
|
break
|
||||||
|
if is_attention_weight:
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = state_dict[name]
|
||||||
|
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||||
|
self._column_parallel_weights,
|
||||||
|
self._row_parallel_weights, tp_rank)
|
||||||
@@ -20,7 +20,7 @@
|
|||||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
InputMetadata to extract the original 2D shape of the input.
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
"""
|
"""
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -36,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
|||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
@@ -68,8 +68,16 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
scaling = self.head_size**-0.5
|
scaling = self.head_size**-0.5
|
||||||
rotary_dim = int(self.head_size * config.rotary_pct)
|
rotary_dim = int(self.head_size * config.rotary_pct)
|
||||||
assert rotary_dim % 2 == 0
|
assert rotary_dim % 2 == 0
|
||||||
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
scaling, rotary_dim)
|
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||||
|
8192)
|
||||||
|
self.attn = PagedAttentionWithRoPE(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_size,
|
||||||
|
scaling,
|
||||||
|
rotary_dim,
|
||||||
|
base=rope_theta,
|
||||||
|
max_position=max_position_embeddings)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -215,7 +223,7 @@ class GPTNeoXForCausalLM(nn.Module):
|
|||||||
kv_caches: List[KVCache],
|
kv_caches: List[KVCache],
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_events: Optional[List[torch.cuda.Event]],
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
|
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
|
||||||
input_metadata, cache_events)
|
input_metadata, cache_events)
|
||||||
next_tokens = self.sampler(self.embed_out.weight, hidden_states,
|
next_tokens = self.sampler(self.embed_out.weight, hidden_states,
|
||||||
@@ -231,11 +239,12 @@ class GPTNeoXForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
use_np_cache: bool = False):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, use_np_cache):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if ("attention.bias" in name or "attention.masked_bias" in name
|
if ("attention.bias" in name or "attention.masked_bias" in name
|
||||||
or "rotary_emb.inv_freq" in name):
|
or "rotary_emb.inv_freq" in name):
|
||||||
continue
|
continue
|
||||||
|
|||||||
305
vllm/model_executor/models/internlm.py
Normal file
305
vllm/model_executor/models/internlm.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
|
ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.weight_utils import (
|
||||||
|
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
|
||||||
|
load_tensor_parallel_weights)
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class InternLMMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
hidden_act: str,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.gate_up_proj = ColumnParallelLinear(hidden_size,
|
||||||
|
2 * intermediate_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.down_proj = RowParallelLinear(intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False)
|
||||||
|
if hidden_act != "silu":
|
||||||
|
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||||
|
"Only silu is supported for now.")
|
||||||
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
|
x = self.act_fn(gate_up)
|
||||||
|
x, _ = self.down_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class InternLMAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
max_position_embeddings: int = 8192,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
tensor_model_parallel_world_size = (
|
||||||
|
get_tensor_model_parallel_world_size())
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
||||||
|
self.num_heads = (self.total_num_heads //
|
||||||
|
tensor_model_parallel_world_size)
|
||||||
|
self.head_dim = hidden_size // self.total_num_heads
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
|
self.qkv_proj = ColumnParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
3 * self.total_num_heads * self.head_dim,
|
||||||
|
bias=True,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False,
|
||||||
|
)
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
self.total_num_heads * self.head_dim,
|
||||||
|
hidden_size,
|
||||||
|
bias=True,
|
||||||
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False,
|
||||||
|
)
|
||||||
|
self.attn = PagedAttentionWithRoPE(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
base=self.rope_theta,
|
||||||
|
max_position=self.max_position_embeddings,
|
||||||
|
rotary_dim=self.head_dim)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||||
|
k_cache, v_cache = kv_cache
|
||||||
|
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||||
|
input_metadata, cache_event)
|
||||||
|
output, _ = self.o_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class InternLMDecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: LlamaConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
|
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||||
|
8192)
|
||||||
|
self.self_attn = InternLMAttention(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
rope_theta=rope_theta,
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
)
|
||||||
|
self.mlp = InternLMMLP(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
hidden_act=config.hidden_act,
|
||||||
|
)
|
||||||
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Self Attention
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
input_metadata=input_metadata,
|
||||||
|
cache_event=cache_event,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class InternLMModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: LlamaConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.padding_idx = config.pad_token_id
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
vocab_size, config.hidden_size, perform_initialization=False)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
InternLMDecoderLayer(config)
|
||||||
|
for _ in range(config.num_hidden_layers)
|
||||||
|
])
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
for i in range(len(self.layers)):
|
||||||
|
if cache_events is None:
|
||||||
|
cache_event = None
|
||||||
|
else:
|
||||||
|
cache_event = cache_events[i]
|
||||||
|
layer = self.layers[i]
|
||||||
|
hidden_states = layer(
|
||||||
|
positions,
|
||||||
|
hidden_states,
|
||||||
|
kv_caches[i],
|
||||||
|
input_metadata,
|
||||||
|
cache_event,
|
||||||
|
)
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class InternLMForCausalLM(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.model = InternLMModel(config)
|
||||||
|
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||||
|
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
||||||
|
vocab_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.sampler = Sampler(config.vocab_size)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> SamplerOutput:
|
||||||
|
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||||
|
input_metadata, cache_events)
|
||||||
|
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||||
|
input_metadata)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
_column_parallel_weights = [
|
||||||
|
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
|
||||||
|
]
|
||||||
|
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||||
|
|
||||||
|
def load_weights(self,
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "embed_tokens" in name or "lm_head" in name:
|
||||||
|
param = state_dict[name]
|
||||||
|
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||||
|
tensor_model_parallel_rank)
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_attention_weight = False
|
||||||
|
for stride_id, att_weight_name in enumerate(
|
||||||
|
["q_proj", "k_proj", "v_proj"]):
|
||||||
|
if att_weight_name not in name:
|
||||||
|
continue
|
||||||
|
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
|
||||||
|
shard_size = param.shape[0] // 3
|
||||||
|
loaded_weight = loaded_weight[
|
||||||
|
shard_size * tensor_model_parallel_rank:shard_size *
|
||||||
|
(tensor_model_parallel_rank + 1)]
|
||||||
|
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||||
|
(stride_id + 1)]
|
||||||
|
assert param_slice.shape == loaded_weight.shape
|
||||||
|
param_slice.copy_(loaded_weight)
|
||||||
|
is_attention_weight = True
|
||||||
|
break
|
||||||
|
if is_attention_weight:
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_gate_up_weight = False
|
||||||
|
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||||
|
shard_size = param.shape[0] // 2
|
||||||
|
loaded_weight = loaded_weight[
|
||||||
|
shard_size * tensor_model_parallel_rank:shard_size *
|
||||||
|
(tensor_model_parallel_rank + 1)]
|
||||||
|
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||||
|
(stride_id + 1)]
|
||||||
|
assert param_slice.shape == loaded_weight.shape
|
||||||
|
param_slice.copy_(loaded_weight)
|
||||||
|
is_gate_up_weight = True
|
||||||
|
break
|
||||||
|
if is_gate_up_weight:
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = state_dict[name]
|
||||||
|
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||||
|
self._column_parallel_weights,
|
||||||
|
self._row_parallel_weights,
|
||||||
|
tensor_model_parallel_rank)
|
||||||
@@ -25,7 +25,7 @@
|
|||||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
InputMetadata to extract the original 2D shape of the input.
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
"""
|
"""
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -36,13 +36,16 @@ from vllm.model_executor.layers.activation import SiluAndMul
|
|||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
from vllm.model_executor.layers.quantized_linear import ParallelLinear
|
||||||
load_tensor_parallel_weights)
|
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
VocabParallelEmbedding)
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.model_executor.quantization_utils import QuantizationConfig
|
||||||
|
from vllm.model_executor.weight_utils import (
|
||||||
|
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||||
|
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
@@ -54,18 +57,21 @@ class LlamaMLP(nn.Module):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
hidden_act: str,
|
hidden_act: str,
|
||||||
):
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gate_up_proj = ColumnParallelLinear(hidden_size,
|
self.gate_up_proj = ParallelLinear.column(hidden_size,
|
||||||
2 * intermediate_size,
|
2 * intermediate_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
gather_output=False,
|
gather_output=False,
|
||||||
perform_initialization=False)
|
perform_initialization=False,
|
||||||
self.down_proj = RowParallelLinear(intermediate_size,
|
quant_config=quant_config)
|
||||||
hidden_size,
|
self.down_proj = ParallelLinear.row(intermediate_size,
|
||||||
bias=False,
|
hidden_size,
|
||||||
input_is_parallel=True,
|
bias=False,
|
||||||
perform_initialization=False)
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False,
|
||||||
|
quant_config=quant_config)
|
||||||
if hidden_act != "silu":
|
if hidden_act != "silu":
|
||||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||||
"Only silu is supported for now.")
|
"Only silu is supported for now.")
|
||||||
@@ -84,36 +90,54 @@ class LlamaAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
):
|
num_kv_heads: int,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
|
max_position_embeddings: int = 8192,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
tensor_model_parallel_world_size = (
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
get_tensor_model_parallel_world_size())
|
|
||||||
self.total_num_heads = num_heads
|
self.total_num_heads = num_heads
|
||||||
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
assert self.total_num_heads % tp_size == 0
|
||||||
self.num_heads = (self.total_num_heads //
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
tensor_model_parallel_world_size)
|
self.total_num_kv_heads = num_kv_heads
|
||||||
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
|
self.num_kv_heads = self.total_num_kv_heads // tp_size
|
||||||
self.head_dim = hidden_size // self.total_num_heads
|
self.head_dim = hidden_size // self.total_num_heads
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
self.qkv_proj = ColumnParallelLinear(
|
self.qkv_proj = ParallelLinear.column(
|
||||||
hidden_size,
|
hidden_size,
|
||||||
3 * self.total_num_heads * self.head_dim,
|
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||||
|
self.head_dim,
|
||||||
bias=False,
|
bias=False,
|
||||||
gather_output=False,
|
gather_output=False,
|
||||||
perform_initialization=False,
|
perform_initialization=False,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.o_proj = RowParallelLinear(
|
self.o_proj = ParallelLinear.row(
|
||||||
self.total_num_heads * self.head_dim,
|
self.total_num_heads * self.head_dim,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
input_is_parallel=True,
|
input_is_parallel=True,
|
||||||
perform_initialization=False,
|
perform_initialization=False,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
self.attn = PagedAttentionWithRoPE(
|
||||||
self.head_dim,
|
self.num_heads,
|
||||||
self.scaling,
|
self.head_dim,
|
||||||
rotary_dim=self.head_dim)
|
self.scaling,
|
||||||
|
base=self.rope_theta,
|
||||||
|
max_position=self.max_position_embeddings,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
rope_scaling=rope_scaling)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -124,7 +148,7 @@ class LlamaAttention(nn.Module):
|
|||||||
cache_event: Optional[torch.cuda.Event],
|
cache_event: Optional[torch.cuda.Event],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
qkv, _ = self.qkv_proj(hidden_states)
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
k_cache, v_cache = kv_cache
|
k_cache, v_cache = kv_cache
|
||||||
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||||
input_metadata, cache_event)
|
input_metadata, cache_event)
|
||||||
@@ -134,17 +158,32 @@ class LlamaAttention(nn.Module):
|
|||||||
|
|
||||||
class LlamaDecoderLayer(nn.Module):
|
class LlamaDecoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config: LlamaConfig):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: LlamaConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
# Requires transformers > 4.32.0
|
||||||
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
|
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||||
|
8192)
|
||||||
self.self_attn = LlamaAttention(
|
self.self_attn = LlamaAttention(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
|
num_kv_heads=config.num_key_value_heads,
|
||||||
|
rope_theta=rope_theta,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.mlp = LlamaMLP(
|
self.mlp = LlamaMLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
hidden_act=config.hidden_act,
|
hidden_act=config.hidden_act,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||||
eps=config.rms_norm_eps)
|
eps=config.rms_norm_eps)
|
||||||
@@ -181,18 +220,22 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
class LlamaModel(nn.Module):
|
class LlamaModel(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config: LlamaConfig):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: LlamaConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||||
self.embed_tokens = VocabParallelEmbedding(
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
config.vocab_size,
|
vocab_size, config.hidden_size, perform_initialization=False)
|
||||||
config.hidden_size,
|
|
||||||
perform_initialization=False)
|
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
|
LlamaDecoderLayer(config, quant_config)
|
||||||
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
@@ -224,15 +267,23 @@ class LlamaModel(nn.Module):
|
|||||||
|
|
||||||
class LlamaForCausalLM(nn.Module):
|
class LlamaForCausalLM(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: LlamaConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.model = LlamaModel(config)
|
self.quant_config = quant_config
|
||||||
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
self.model = LlamaModel(config, quant_config)
|
||||||
config.vocab_size,
|
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||||
bias=False,
|
# NOTE: The LM head is not quantized.
|
||||||
gather_output=False,
|
self.lm_head = ParallelLinear.column(config.hidden_size,
|
||||||
perform_initialization=False)
|
vocab_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False,
|
||||||
|
quant_config=None)
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.sampler = Sampler(config.vocab_size)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -242,44 +293,82 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
kv_caches: List[KVCache],
|
kv_caches: List[KVCache],
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_events: Optional[List[torch.cuda.Event]],
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||||
input_metadata, cache_events)
|
input_metadata, cache_events)
|
||||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||||
input_metadata)
|
input_metadata)
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
_column_parallel_weights = [
|
_column_parallel_layers = []
|
||||||
"embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
|
_row_parallel_layers = ["o_proj", "down_proj"]
|
||||||
"gate_proj.weight", "up_proj.weight"
|
|
||||||
]
|
|
||||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
use_np_cache: bool = False):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
|
if self.quant_config is None:
|
||||||
|
weight_suffixes = ["weight"]
|
||||||
|
else:
|
||||||
|
weight_suffixes = self.quant_config.get_tp_tensor_names()
|
||||||
|
|
||||||
|
column_parallel_weights: List[str] = []
|
||||||
|
for layer in self._column_parallel_layers:
|
||||||
|
for suffix in weight_suffixes:
|
||||||
|
column_parallel_weights.append(f"{layer}.{suffix}")
|
||||||
|
row_parallel_weights: List[str] = []
|
||||||
|
for layer in self._row_parallel_layers:
|
||||||
|
for suffix in weight_suffixes:
|
||||||
|
row_parallel_weights.append(f"{layer}.{suffix}")
|
||||||
|
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
|
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||||
|
kv_proj_shard_size = (self.config.hidden_size //
|
||||||
|
self.config.num_attention_heads *
|
||||||
|
self.config.num_key_value_heads // tp_size)
|
||||||
|
attention_weight_specs = [
|
||||||
|
# (weight_name, shard_size, offset)
|
||||||
|
("q_proj", q_proj_shard_size, 0),
|
||||||
|
("k_proj", kv_proj_shard_size, q_proj_shard_size),
|
||||||
|
("v_proj", kv_proj_shard_size,
|
||||||
|
q_proj_shard_size + kv_proj_shard_size),
|
||||||
|
]
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, use_np_cache):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
is_packed = False
|
||||||
|
is_transposed = False
|
||||||
|
if self.quant_config is not None:
|
||||||
|
is_packed = self.quant_config.is_packed(name)
|
||||||
|
is_transposed = self.quant_config.is_transposed(name)
|
||||||
|
if is_transposed:
|
||||||
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
|
loaded_weight = loaded_weight.T
|
||||||
|
|
||||||
is_attention_weight = False
|
is_attention_weight = False
|
||||||
for stride_id, att_weight_name in enumerate(
|
for weight_name, shard_size, offset in attention_weight_specs:
|
||||||
["q_proj", "k_proj", "v_proj"]):
|
if weight_name not in name:
|
||||||
if att_weight_name not in name:
|
|
||||||
continue
|
continue
|
||||||
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
|
param = state_dict[name.replace(weight_name, "qkv_proj")]
|
||||||
shard_size = param.shape[0] // 3
|
if is_transposed:
|
||||||
|
param = param.T
|
||||||
|
|
||||||
|
if is_packed:
|
||||||
|
shard_size //= self.quant_config.pack_factor
|
||||||
|
offset //= self.quant_config.pack_factor
|
||||||
|
|
||||||
loaded_weight = loaded_weight[
|
loaded_weight = loaded_weight[
|
||||||
shard_size * tensor_model_parallel_rank:shard_size *
|
shard_size * tensor_model_parallel_rank:shard_size *
|
||||||
(tensor_model_parallel_rank + 1)]
|
(tensor_model_parallel_rank + 1)]
|
||||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
param_slice = param.data[offset:offset + shard_size]
|
||||||
(stride_id + 1)]
|
|
||||||
assert param_slice.shape == loaded_weight.shape
|
assert param_slice.shape == loaded_weight.shape
|
||||||
|
|
||||||
param_slice.copy_(loaded_weight)
|
param_slice.copy_(loaded_weight)
|
||||||
is_attention_weight = True
|
is_attention_weight = True
|
||||||
break
|
break
|
||||||
@@ -291,6 +380,9 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||||
|
if is_transposed:
|
||||||
|
param = param.T
|
||||||
|
|
||||||
shard_size = param.shape[0] // 2
|
shard_size = param.shape[0] // 2
|
||||||
loaded_weight = loaded_weight[
|
loaded_weight = loaded_weight[
|
||||||
shard_size * tensor_model_parallel_rank:shard_size *
|
shard_size * tensor_model_parallel_rank:shard_size *
|
||||||
@@ -305,7 +397,15 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
param = state_dict[name]
|
param = state_dict[name]
|
||||||
|
if is_transposed:
|
||||||
|
param = param.T
|
||||||
|
|
||||||
|
if "embed_tokens" in name or "lm_head" in name:
|
||||||
|
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||||
|
tensor_model_parallel_rank)
|
||||||
|
continue
|
||||||
|
|
||||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||||
self._column_parallel_weights,
|
column_parallel_weights,
|
||||||
self._row_parallel_weights,
|
row_parallel_weights,
|
||||||
tensor_model_parallel_rank)
|
tensor_model_parallel_rank)
|
||||||
|
|||||||
404
vllm/model_executor/models/mistral.py
Normal file
404
vllm/model_executor/models/mistral.py
Normal file
@@ -0,0 +1,404 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Inference-only LLaMA model compatible with HuggingFace weights.
|
||||||
|
|
||||||
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
|
"""
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.layers.quantized_linear import ParallelLinear
|
||||||
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
|
VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.quantization_utils import QuantizationConfig
|
||||||
|
from vllm.model_executor.weight_utils import (
|
||||||
|
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||||
|
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
|
from vllm.transformers_utils.configs.mistral import MistralConfig
|
||||||
|
|
||||||
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class MistralMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
hidden_act: str,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.gate_up_proj = ParallelLinear.column(hidden_size,
|
||||||
|
2 * intermediate_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False,
|
||||||
|
quant_config=quant_config)
|
||||||
|
self.down_proj = ParallelLinear.row(intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False,
|
||||||
|
quant_config=quant_config)
|
||||||
|
if hidden_act != "silu":
|
||||||
|
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||||
|
"Only silu is supported for now.")
|
||||||
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
|
x = self.act_fn(gate_up)
|
||||||
|
x, _ = self.down_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MistralAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
max_position: int = 4096 * 32,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
sliding_window: Optional[int] = None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.total_num_kv_heads = num_kv_heads
|
||||||
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
|
self.num_kv_heads = self.total_num_kv_heads // tp_size
|
||||||
|
self.head_dim = hidden_size // self.total_num_heads
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
|
||||||
|
self.qkv_proj = ParallelLinear.column(
|
||||||
|
hidden_size,
|
||||||
|
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||||
|
self.head_dim,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.o_proj = ParallelLinear.row(
|
||||||
|
self.total_num_heads * self.head_dim,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
base=self.rope_theta,
|
||||||
|
max_position=max_position,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
sliding_window=self.sliding_window)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
k_cache, v_cache = kv_cache
|
||||||
|
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||||
|
input_metadata, cache_event)
|
||||||
|
output, _ = self.o_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class MistralDecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: MistralConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
# Requires transformers > 4.32.0
|
||||||
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
|
self.self_attn = MistralAttention(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
max_position=config.max_position_embeddings,
|
||||||
|
num_kv_heads=config.num_key_value_heads,
|
||||||
|
rope_theta=rope_theta,
|
||||||
|
quant_config=quant_config,
|
||||||
|
sliding_window=config.sliding_window)
|
||||||
|
self.mlp = MistralMLP(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
hidden_act=config.hidden_act,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Self Attention
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
input_metadata=input_metadata,
|
||||||
|
cache_event=cache_event,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class MistralModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: MistralConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.padding_idx = config.pad_token_id
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
vocab_size, config.hidden_size, perform_initialization=False)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
MistralDecoderLayer(config, quant_config)
|
||||||
|
for _ in range(config.num_hidden_layers)
|
||||||
|
])
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
for i in range(len(self.layers)):
|
||||||
|
if cache_events is None:
|
||||||
|
cache_event = None
|
||||||
|
else:
|
||||||
|
cache_event = cache_events[i]
|
||||||
|
layer = self.layers[i]
|
||||||
|
hidden_states = layer(
|
||||||
|
positions,
|
||||||
|
hidden_states,
|
||||||
|
kv_caches[i],
|
||||||
|
input_metadata,
|
||||||
|
cache_event,
|
||||||
|
)
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class MistralForCausalLM(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: MistralConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.model = MistralModel(config, quant_config)
|
||||||
|
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||||
|
# NOTE: The LM head is not quantized.
|
||||||
|
self.lm_head = ParallelLinear.column(config.hidden_size,
|
||||||
|
vocab_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False,
|
||||||
|
quant_config=None)
|
||||||
|
self.sampler = Sampler(config.vocab_size)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> SamplerOutput:
|
||||||
|
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||||
|
input_metadata, cache_events)
|
||||||
|
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||||
|
input_metadata)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
_column_parallel_layers = []
|
||||||
|
_row_parallel_layers = ["o_proj", "down_proj"]
|
||||||
|
|
||||||
|
def load_weights(self,
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
|
if self.quant_config is None:
|
||||||
|
weight_suffixes = ["weight"]
|
||||||
|
else:
|
||||||
|
weight_suffixes = self.quant_config.get_tp_tensor_names()
|
||||||
|
|
||||||
|
column_parallel_weights: List[str] = []
|
||||||
|
for layer in self._column_parallel_layers:
|
||||||
|
for suffix in weight_suffixes:
|
||||||
|
column_parallel_weights.append(f"{layer}.{suffix}")
|
||||||
|
row_parallel_weights: List[str] = []
|
||||||
|
for layer in self._row_parallel_layers:
|
||||||
|
for suffix in weight_suffixes:
|
||||||
|
row_parallel_weights.append(f"{layer}.{suffix}")
|
||||||
|
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
|
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||||
|
kv_proj_shard_size = (self.config.hidden_size //
|
||||||
|
self.config.num_attention_heads *
|
||||||
|
self.config.num_key_value_heads // tp_size)
|
||||||
|
attention_weight_specs = [
|
||||||
|
# (weight_name, shard_size, offset)
|
||||||
|
("q_proj", q_proj_shard_size, 0),
|
||||||
|
("k_proj", kv_proj_shard_size, q_proj_shard_size),
|
||||||
|
("v_proj", kv_proj_shard_size,
|
||||||
|
q_proj_shard_size + kv_proj_shard_size),
|
||||||
|
]
|
||||||
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_packed = False
|
||||||
|
is_transposed = False
|
||||||
|
if self.quant_config is not None:
|
||||||
|
is_packed = self.quant_config.is_packed(name)
|
||||||
|
is_transposed = self.quant_config.is_transposed(name)
|
||||||
|
if is_transposed:
|
||||||
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
|
loaded_weight = loaded_weight.T
|
||||||
|
|
||||||
|
is_attention_weight = False
|
||||||
|
for weight_name, shard_size, offset in attention_weight_specs:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
param = state_dict[name.replace(weight_name, "qkv_proj")]
|
||||||
|
if is_transposed:
|
||||||
|
param = param.T
|
||||||
|
|
||||||
|
if is_packed:
|
||||||
|
shard_size //= self.quant_config.pack_factor
|
||||||
|
offset //= self.quant_config.pack_factor
|
||||||
|
|
||||||
|
loaded_weight = loaded_weight[
|
||||||
|
shard_size * tensor_model_parallel_rank:shard_size *
|
||||||
|
(tensor_model_parallel_rank + 1)]
|
||||||
|
param_slice = param.data[offset:offset + shard_size]
|
||||||
|
assert param_slice.shape == loaded_weight.shape
|
||||||
|
|
||||||
|
param_slice.copy_(loaded_weight)
|
||||||
|
is_attention_weight = True
|
||||||
|
break
|
||||||
|
if is_attention_weight:
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_gate_up_weight = False
|
||||||
|
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||||
|
if is_transposed:
|
||||||
|
param = param.T
|
||||||
|
|
||||||
|
shard_size = param.shape[0] // 2
|
||||||
|
loaded_weight = loaded_weight[
|
||||||
|
shard_size * tensor_model_parallel_rank:shard_size *
|
||||||
|
(tensor_model_parallel_rank + 1)]
|
||||||
|
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||||
|
(stride_id + 1)]
|
||||||
|
assert param_slice.shape == loaded_weight.shape
|
||||||
|
param_slice.copy_(loaded_weight)
|
||||||
|
is_gate_up_weight = True
|
||||||
|
break
|
||||||
|
if is_gate_up_weight:
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = state_dict[name]
|
||||||
|
if is_transposed:
|
||||||
|
param = param.T
|
||||||
|
|
||||||
|
if "embed_tokens" in name or "lm_head" in name:
|
||||||
|
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||||
|
tensor_model_parallel_rank)
|
||||||
|
continue
|
||||||
|
|
||||||
|
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||||
|
column_parallel_weights,
|
||||||
|
row_parallel_weights,
|
||||||
|
tensor_model_parallel_rank)
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
|
# coding=utf-8
|
||||||
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
|
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
|
||||||
import math
|
import math
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -9,13 +10,14 @@ from vllm.model_executor.input_metadata import InputMetadata
|
|||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
|
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
|
||||||
|
hf_model_weights_iterator,
|
||||||
load_tensor_parallel_weights)
|
load_tensor_parallel_weights)
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.sequence import SamplerOutput
|
||||||
from vllm.transformers_utils.configs.mpt import MPTConfig
|
from vllm.transformers_utils.configs.mpt import MPTConfig
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
@@ -229,7 +231,7 @@ class MPTForCausalLM(nn.Module):
|
|||||||
kv_caches: List[KVCache],
|
kv_caches: List[KVCache],
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_events: Optional[List[torch.cuda.Event]],
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||||
input_metadata, cache_events)
|
input_metadata, cache_events)
|
||||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||||
@@ -242,12 +244,13 @@ class MPTForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
use_np_cache: bool = False):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
tp_world_size = get_tensor_model_parallel_world_size()
|
tp_world_size = get_tensor_model_parallel_world_size()
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, use_np_cache):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if "Wqkv" in name:
|
if "Wqkv" in name:
|
||||||
# NOTE(woosuk): MPT's fused QKV has the shape of
|
# NOTE(woosuk): MPT's fused QKV has the shape of
|
||||||
# [3 * num_heads * head_size, hidden_size].
|
# [3 * num_heads * head_size, hidden_size].
|
||||||
@@ -259,7 +262,7 @@ class MPTForCausalLM(nn.Module):
|
|||||||
num_heads = total_num_heads // tp_world_size
|
num_heads = total_num_heads // tp_world_size
|
||||||
head_start = tp_rank * num_heads
|
head_start = tp_rank * num_heads
|
||||||
head_end = (tp_rank + 1) * num_heads
|
head_end = (tp_rank + 1) * num_heads
|
||||||
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
if name.endswith(".weight"):
|
if name.endswith(".weight"):
|
||||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||||
head_size, hidden_size)
|
head_size, hidden_size)
|
||||||
|
|||||||
@@ -21,7 +21,7 @@
|
|||||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
InputMetadata to extract the original 2D shape of the input.
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
"""
|
"""
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -37,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
|||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
@@ -282,7 +282,7 @@ class OPTForCausalLM(nn.Module):
|
|||||||
kv_caches: List[KVCache],
|
kv_caches: List[KVCache],
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_events: Optional[List[torch.cuda.Event]],
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||||
input_metadata, cache_events)
|
input_metadata, cache_events)
|
||||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||||
@@ -297,12 +297,13 @@ class OPTForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
use_np_cache: bool = False):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, use_np_cache):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if "lm_head.weight" in name:
|
if "lm_head.weight" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
327
vllm/model_executor/models/qwen.py
Normal file
327
vllm/model_executor/models/qwen.py
Normal file
@@ -0,0 +1,327 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Adapted from
|
||||||
|
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
|
||||||
|
# Copyright (c) Alibaba Cloud.
|
||||||
|
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
|
||||||
|
"""Inference-only QWen model compatible with HuggingFace weights.
|
||||||
|
|
||||||
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
|
"""
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.weight_utils import (
|
||||||
|
convert_pyslice_to_tensor,
|
||||||
|
hf_model_weights_iterator,
|
||||||
|
load_padded_tensor_parallel_vocab,
|
||||||
|
load_tensor_parallel_weights,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
|
VocabParallelEmbedding,
|
||||||
|
ColumnParallelLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
)
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
|
from vllm.transformers_utils.configs.qwen import QWenConfig
|
||||||
|
|
||||||
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class QWenMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
hidden_act: str = "silu",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.gate_up_proj = ColumnParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
2 * intermediate_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False,
|
||||||
|
)
|
||||||
|
self.c_proj = RowParallelLinear(
|
||||||
|
intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False,
|
||||||
|
)
|
||||||
|
if hidden_act != "silu":
|
||||||
|
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||||
|
"Only silu is supported for now.")
|
||||||
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
|
x = self.act_fn(gate_up)
|
||||||
|
x, _ = self.c_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class QWenAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
max_position_embeddings: int,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
rope_scaling: Optional[Dict[str, Any]] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
|
||||||
|
)
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
||||||
|
self.num_heads = (self.total_num_heads //
|
||||||
|
tensor_model_parallel_world_size)
|
||||||
|
self.head_dim = hidden_size // self.total_num_heads
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
self.c_attn = ColumnParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
3 * hidden_size,
|
||||||
|
bias=True,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False,
|
||||||
|
)
|
||||||
|
self.c_proj = RowParallelLinear(
|
||||||
|
self.total_num_heads * self.head_dim,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False,
|
||||||
|
)
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.attn = PagedAttentionWithRoPE(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
base=rope_theta,
|
||||||
|
max_position=max_position_embeddings,
|
||||||
|
rope_scaling=rope_scaling)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.c_attn(hidden_states)
|
||||||
|
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||||
|
|
||||||
|
k_cache, v_cache = kv_cache
|
||||||
|
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||||
|
input_metadata, cache_event)
|
||||||
|
|
||||||
|
output, _ = self.c_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class QWenBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: QWenConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
|
self.attn = QWenAttention(config.hidden_size,
|
||||||
|
config.num_attention_heads,
|
||||||
|
config.max_position_embeddings,
|
||||||
|
rope_theta=rope_theta,
|
||||||
|
rope_scaling=rope_scaling)
|
||||||
|
|
||||||
|
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
|
self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Self Attention
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.ln_1(hidden_states)
|
||||||
|
hidden_states = self.attn(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
input_metadata=input_metadata,
|
||||||
|
cache_event=cache_event,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.ln_2(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class QWenModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: QWenConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||||
|
self.wte = VocabParallelEmbedding(vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.h = nn.ModuleList(
|
||||||
|
[QWenBlock(config) for _ in range(config.num_hidden_layers)])
|
||||||
|
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.wte(input_ids)
|
||||||
|
for i in range(len(self.h)):
|
||||||
|
if cache_events is None:
|
||||||
|
cache_event = None
|
||||||
|
else:
|
||||||
|
cache_event = cache_events[i]
|
||||||
|
layer = self.h[i]
|
||||||
|
hidden_states = layer(
|
||||||
|
positions,
|
||||||
|
hidden_states,
|
||||||
|
kv_caches[i],
|
||||||
|
input_metadata,
|
||||||
|
cache_event,
|
||||||
|
)
|
||||||
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class QWenLMHeadModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: QWenConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.transformer = QWenModel(config)
|
||||||
|
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||||
|
self.lm_head = ColumnParallelLinear(
|
||||||
|
config.hidden_size,
|
||||||
|
vocab_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False,
|
||||||
|
)
|
||||||
|
self.sampler = Sampler(config.vocab_size)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> SamplerOutput:
|
||||||
|
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||||
|
input_metadata, cache_events)
|
||||||
|
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||||
|
input_metadata)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
_column_parallel_weights = []
|
||||||
|
_row_parallel_weights = ["c_proj.weight"]
|
||||||
|
|
||||||
|
def load_weights(
|
||||||
|
self,
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
):
|
||||||
|
tp_world_size = get_tensor_model_parallel_world_size()
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
|
|
||||||
|
if "c_attn" in name:
|
||||||
|
total_num_heads = self.config.num_attention_heads
|
||||||
|
hidden_size = self.config.hidden_size
|
||||||
|
head_size = hidden_size // total_num_heads
|
||||||
|
num_heads = total_num_heads // tp_world_size
|
||||||
|
head_start = tp_rank * num_heads
|
||||||
|
head_end = (tp_rank + 1) * num_heads
|
||||||
|
|
||||||
|
if "weight" in name:
|
||||||
|
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||||
|
head_size, hidden_size)
|
||||||
|
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
|
||||||
|
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||||
|
elif "bias" in name:
|
||||||
|
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||||
|
head_size)
|
||||||
|
loaded_weight = loaded_weight[:, head_start:head_end, :]
|
||||||
|
loaded_weight = loaded_weight.reshape(-1)
|
||||||
|
|
||||||
|
is_gate_up_weight = False
|
||||||
|
for stride_id, weight_name in enumerate(["w2", "w1"]):
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||||
|
shard_size = param.shape[0] // 2
|
||||||
|
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
|
||||||
|
(tp_rank + 1)]
|
||||||
|
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||||
|
(stride_id + 1)]
|
||||||
|
assert param_slice.shape == loaded_weight.shape
|
||||||
|
param_slice.copy_(loaded_weight)
|
||||||
|
is_gate_up_weight = True
|
||||||
|
break
|
||||||
|
if is_gate_up_weight:
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = state_dict[name]
|
||||||
|
|
||||||
|
if "wte" in name or "lm_head" in name:
|
||||||
|
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||||
|
tp_rank)
|
||||||
|
continue
|
||||||
|
|
||||||
|
load_tensor_parallel_weights(
|
||||||
|
param,
|
||||||
|
loaded_weight,
|
||||||
|
name,
|
||||||
|
self._column_parallel_weights,
|
||||||
|
self._row_parallel_weights,
|
||||||
|
tp_rank,
|
||||||
|
)
|
||||||
@@ -1,9 +1,6 @@
|
|||||||
import vllm.model_executor.parallel_utils.parallel_state
|
import vllm.model_executor.parallel_utils.parallel_state
|
||||||
import vllm.model_executor.parallel_utils.tensor_parallel
|
import vllm.model_executor.parallel_utils.tensor_parallel
|
||||||
|
|
||||||
# Alias parallel_state as mpu, its legacy name
|
|
||||||
mpu = parallel_state
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"parallel_state",
|
"parallel_state",
|
||||||
"tensor_parallel",
|
"tensor_parallel",
|
||||||
|
|||||||
@@ -44,7 +44,6 @@ _PIPELINE_GLOBAL_RANKS = None
|
|||||||
# rank when broadcasting weights from src to all other data parallel ranks
|
# rank when broadcasting weights from src to all other data parallel ranks
|
||||||
_DATA_PARALLEL_GLOBAL_RANKS = None
|
_DATA_PARALLEL_GLOBAL_RANKS = None
|
||||||
|
|
||||||
_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None
|
|
||||||
|
|
||||||
def initialize_model_parallel(
|
def initialize_model_parallel(
|
||||||
tensor_model_parallel_size: int = 1,
|
tensor_model_parallel_size: int = 1,
|
||||||
@@ -196,20 +195,6 @@ def initialize_model_parallel(
|
|||||||
if rank in ranks:
|
if rank in ranks:
|
||||||
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
|
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
|
||||||
|
|
||||||
def initialize_all_reduce_launcher(
|
|
||||||
max_num_tokens: int,
|
|
||||||
hidden_size: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
disable_graph: bool = False,
|
|
||||||
) -> None:
|
|
||||||
global _ALL_REDUCE_LAUNCHER
|
|
||||||
_ALL_REDUCE_LAUNCHER = GraphAllReduce(
|
|
||||||
max_num_tokens=max_num_tokens,
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
dtype=dtype,
|
|
||||||
disable_graph=disable_graph,
|
|
||||||
)
|
|
||||||
|
|
||||||
def model_parallel_is_initialized():
|
def model_parallel_is_initialized():
|
||||||
"""Check if model and data parallel groups are initialized."""
|
"""Check if model and data parallel groups are initialized."""
|
||||||
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
|
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
|
||||||
@@ -458,6 +443,7 @@ def get_pipeline_model_parallel_last_rank():
|
|||||||
last_rank_local = get_pipeline_model_parallel_world_size() - 1
|
last_rank_local = get_pipeline_model_parallel_world_size() - 1
|
||||||
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
|
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_model_parallel_next_rank():
|
def get_pipeline_model_parallel_next_rank():
|
||||||
"""Return the global rank that follows the caller in the pipeline"""
|
"""Return the global rank that follows the caller in the pipeline"""
|
||||||
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
||||||
@@ -485,10 +471,6 @@ def get_data_parallel_rank():
|
|||||||
"""Return my rank for the data parallel group."""
|
"""Return my rank for the data parallel group."""
|
||||||
return torch.distributed.get_rank(group=get_data_parallel_group())
|
return torch.distributed.get_rank(group=get_data_parallel_group())
|
||||||
|
|
||||||
def get_all_reduce_launcher() -> 'GraphAllReduce':
|
|
||||||
assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized'
|
|
||||||
return _ALL_REDUCE_LAUNCHER
|
|
||||||
|
|
||||||
def destroy_model_parallel():
|
def destroy_model_parallel():
|
||||||
"""Set the groups to none."""
|
"""Set the groups to none."""
|
||||||
global _MODEL_PARALLEL_GROUP
|
global _MODEL_PARALLEL_GROUP
|
||||||
@@ -515,56 +497,3 @@ def destroy_model_parallel():
|
|||||||
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
|
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
|
||||||
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
||||||
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
|
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
|
||||||
|
|
||||||
|
|
||||||
class GraphAllReduce:
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
max_num_tokens: int,
|
|
||||||
hidden_size: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
disable_graph: bool = False,
|
|
||||||
) -> None:
|
|
||||||
self.max_num_tokens = max_num_tokens
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.disable_graph = disable_graph
|
|
||||||
|
|
||||||
tp_world_size = get_tensor_model_parallel_world_size()
|
|
||||||
if tp_world_size == 1:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.group = get_tensor_model_parallel_group()
|
|
||||||
self.buffer = torch.empty(
|
|
||||||
size=(max_num_tokens, hidden_size),
|
|
||||||
dtype=dtype,
|
|
||||||
device='cuda',
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build graphs for different number of tokens.
|
|
||||||
if not self.disable_graph:
|
|
||||||
self.graphs = {}
|
|
||||||
for num_tokens in range(8, max_num_tokens + 1, 8):
|
|
||||||
self.graphs[num_tokens] = self._build_graph(num_tokens)
|
|
||||||
|
|
||||||
def _build_graph(self, num_tokens: int) -> torch.cuda.CUDAGraph:
|
|
||||||
# Warm up.
|
|
||||||
torch.distributed.all_reduce(self.buffer[:num_tokens], group=self.group)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
# Build graph.
|
|
||||||
graph = torch.cuda.CUDAGraph()
|
|
||||||
with torch.cuda.graph(graph):
|
|
||||||
torch.distributed.all_reduce(
|
|
||||||
self.buffer[:num_tokens], group=self.group)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
return graph
|
|
||||||
|
|
||||||
def launch(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
# NOTE: x must be a slice of self.buffer.
|
|
||||||
num_tokens = x.shape[0]
|
|
||||||
if self.disable_graph:
|
|
||||||
torch.distributed.all_reduce(x, group=self.group)
|
|
||||||
else:
|
|
||||||
self.graphs[num_tokens].replay()
|
|
||||||
return x
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from .mappings import (
|
|||||||
copy_to_tensor_model_parallel_region,
|
copy_to_tensor_model_parallel_region,
|
||||||
gather_from_tensor_model_parallel_region,
|
gather_from_tensor_model_parallel_region,
|
||||||
gather_from_sequence_parallel_region,
|
gather_from_sequence_parallel_region,
|
||||||
|
reduce_from_tensor_model_parallel_region,
|
||||||
scatter_to_tensor_model_parallel_region,
|
scatter_to_tensor_model_parallel_region,
|
||||||
scatter_to_sequence_parallel_region,
|
scatter_to_sequence_parallel_region,
|
||||||
)
|
)
|
||||||
@@ -38,7 +39,7 @@ __all__ = [
|
|||||||
"copy_to_tensor_model_parallel_region",
|
"copy_to_tensor_model_parallel_region",
|
||||||
"gather_from_tensor_model_parallel_region",
|
"gather_from_tensor_model_parallel_region",
|
||||||
"gather_from_sequence_parallel_region",
|
"gather_from_sequence_parallel_region",
|
||||||
# "reduce_from_tensor_model_parallel_region",
|
"reduce_from_tensor_model_parallel_region",
|
||||||
"scatter_to_tensor_model_parallel_region",
|
"scatter_to_tensor_model_parallel_region",
|
||||||
"scatter_to_sequence_parallel_region",
|
"scatter_to_sequence_parallel_region",
|
||||||
# random.py
|
# random.py
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
# Parts of the code here are adapted from PyTorch
|
# Parts of the code here are adapted from PyTorch
|
||||||
# repo: https://github.com/pytorch/pytorch
|
# repo: https://github.com/pytorch/pytorch
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -14,16 +14,13 @@ from torch.nn.parameter import Parameter
|
|||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
get_all_reduce_launcher,
|
|
||||||
)
|
)
|
||||||
from .mappings import (
|
from .mappings import (
|
||||||
copy_to_tensor_model_parallel_region,
|
|
||||||
gather_from_tensor_model_parallel_region,
|
gather_from_tensor_model_parallel_region,
|
||||||
reduce_from_tensor_model_parallel_region,
|
reduce_from_tensor_model_parallel_region,
|
||||||
scatter_to_tensor_model_parallel_region,
|
scatter_to_tensor_model_parallel_region,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .random import get_cuda_rng_tracker
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
divide,
|
divide,
|
||||||
VocabUtility,
|
VocabUtility,
|
||||||
@@ -66,59 +63,6 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
|
|||||||
maybe_copy(attribute)
|
maybe_copy(attribute)
|
||||||
|
|
||||||
|
|
||||||
def _initialize_affine_weight_gpu(weight, init_method,
|
|
||||||
partition_dim, stride=1):
|
|
||||||
"""Initialize affine weight for model parallel on GPU."""
|
|
||||||
|
|
||||||
set_tensor_model_parallel_attributes(tensor=weight,
|
|
||||||
is_parallel=True,
|
|
||||||
dim=partition_dim,
|
|
||||||
stride=stride)
|
|
||||||
|
|
||||||
with get_cuda_rng_tracker().fork():
|
|
||||||
init_method(weight)
|
|
||||||
|
|
||||||
|
|
||||||
def _initialize_affine_weight_cpu(weight, output_size, input_size,
|
|
||||||
per_partition_size, partition_dim,
|
|
||||||
init_method, stride=1,
|
|
||||||
return_master_weight=False,
|
|
||||||
*, params_dtype=None):
|
|
||||||
"""Initialize affine weight for model parallel.
|
|
||||||
|
|
||||||
Build the master weight on all processes and scatter
|
|
||||||
the relevant chunk."""
|
|
||||||
|
|
||||||
set_tensor_model_parallel_attributes(tensor=weight,
|
|
||||||
is_parallel=True,
|
|
||||||
dim=partition_dim,
|
|
||||||
stride=stride)
|
|
||||||
|
|
||||||
if params_dtype is None:
|
|
||||||
params_dtype = torch.get_default_dtype()
|
|
||||||
|
|
||||||
# Initialize master weight
|
|
||||||
master_weight = torch.empty(output_size, input_size,
|
|
||||||
dtype=torch.float,
|
|
||||||
requires_grad=False)
|
|
||||||
init_method(master_weight)
|
|
||||||
master_weight = master_weight.to(dtype=params_dtype)
|
|
||||||
|
|
||||||
# Split and copy
|
|
||||||
per_partition_per_stride_size = divide(per_partition_size, stride)
|
|
||||||
weight_list = torch.split(master_weight, per_partition_per_stride_size,
|
|
||||||
dim=partition_dim)
|
|
||||||
rank = get_tensor_model_parallel_rank()
|
|
||||||
world_size = get_tensor_model_parallel_world_size()
|
|
||||||
my_weight_list = weight_list[rank::world_size]
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
torch.cat(my_weight_list, dim=partition_dim, out=weight)
|
|
||||||
if return_master_weight:
|
|
||||||
return master_weight
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class VocabParallelEmbedding(torch.nn.Module):
|
class VocabParallelEmbedding(torch.nn.Module):
|
||||||
"""Embedding parallelized in the vocabulary dimension.
|
"""Embedding parallelized in the vocabulary dimension.
|
||||||
|
|
||||||
@@ -139,8 +83,11 @@ class VocabParallelEmbedding(torch.nn.Module):
|
|||||||
init_method=init.xavier_normal_,
|
init_method=init.xavier_normal_,
|
||||||
params_dtype: torch.dtype=None,
|
params_dtype: torch.dtype=None,
|
||||||
use_cpu_initialization: bool=False,
|
use_cpu_initialization: bool=False,
|
||||||
perform_initialization: bool=True):
|
perform_initialization: bool=False):
|
||||||
super(VocabParallelEmbedding, self).__init__()
|
super(VocabParallelEmbedding, self).__init__()
|
||||||
|
assert not perform_initialization
|
||||||
|
assert not use_cpu_initialization
|
||||||
|
|
||||||
# Keep the input dimensions.
|
# Keep the input dimensions.
|
||||||
self.num_embeddings = num_embeddings
|
self.num_embeddings = num_embeddings
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
@@ -163,23 +110,9 @@ class VocabParallelEmbedding(torch.nn.Module):
|
|||||||
self.num_embeddings_per_partition = self.vocab_end_index - \
|
self.num_embeddings_per_partition = self.vocab_end_index - \
|
||||||
self.vocab_start_index
|
self.vocab_start_index
|
||||||
|
|
||||||
# Allocate weights and initialize.
|
self.weight = Parameter(torch.empty(
|
||||||
if use_cpu_initialization:
|
self.num_embeddings_per_partition, self.embedding_dim,
|
||||||
self.weight = Parameter(torch.empty(
|
device=torch.cuda.current_device(), dtype=params_dtype))
|
||||||
self.num_embeddings_per_partition, self.embedding_dim,
|
|
||||||
dtype=params_dtype))
|
|
||||||
if perform_initialization:
|
|
||||||
_initialize_affine_weight_cpu(
|
|
||||||
self.weight, self.num_embeddings, self.embedding_dim,
|
|
||||||
self.num_embeddings_per_partition, 0, init_method,
|
|
||||||
params_dtype=params_dtype)
|
|
||||||
else:
|
|
||||||
self.weight = Parameter(torch.empty(
|
|
||||||
self.num_embeddings_per_partition, self.embedding_dim,
|
|
||||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
|
||||||
if perform_initialization:
|
|
||||||
_initialize_affine_weight_gpu(self.weight, init_method,
|
|
||||||
partition_dim=0, stride=1)
|
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(self, input_):
|
||||||
if self.tensor_model_parallel_size > 1:
|
if self.tensor_model_parallel_size > 1:
|
||||||
@@ -239,18 +172,22 @@ class ColumnParallelLinear(torch.nn.Module):
|
|||||||
skip_bias_add=False,
|
skip_bias_add=False,
|
||||||
params_dtype=None,
|
params_dtype=None,
|
||||||
use_cpu_initialization=False,
|
use_cpu_initialization=False,
|
||||||
perform_initialization=True,
|
perform_initialization=False,
|
||||||
|
quant_config=None,
|
||||||
):
|
):
|
||||||
super(ColumnParallelLinear, self).__init__()
|
super(ColumnParallelLinear, self).__init__()
|
||||||
|
assert not perform_initialization
|
||||||
|
assert not use_cpu_initialization
|
||||||
|
|
||||||
# Keep input parameters
|
# Keep input parameters
|
||||||
self.input_size = input_size
|
self.input_size = input_size
|
||||||
self.output_size = output_size
|
self.output_size = output_size
|
||||||
self.gather_output = gather_output
|
self.gather_output = gather_output
|
||||||
# Divide the weight matrix along the last dimension.
|
# Divide the weight matrix along the last dimension.
|
||||||
world_size = get_tensor_model_parallel_world_size()
|
self.world_size = get_tensor_model_parallel_world_size()
|
||||||
self.output_size_per_partition = divide(output_size, world_size)
|
self.output_size_per_partition = divide(output_size, self.world_size)
|
||||||
self.skip_bias_add = skip_bias_add
|
self.skip_bias_add = skip_bias_add
|
||||||
|
self.quant_config = quant_config
|
||||||
|
|
||||||
if params_dtype is None:
|
if params_dtype is None:
|
||||||
params_dtype = torch.get_default_dtype()
|
params_dtype = torch.get_default_dtype()
|
||||||
@@ -258,33 +195,13 @@ class ColumnParallelLinear(torch.nn.Module):
|
|||||||
# Parameters.
|
# Parameters.
|
||||||
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
||||||
# we allocate the transpose.
|
# we allocate the transpose.
|
||||||
# Initialize weight.
|
self.create_weights(params_dtype)
|
||||||
if use_cpu_initialization:
|
|
||||||
self.weight = Parameter(torch.empty(self.output_size_per_partition,
|
|
||||||
self.input_size,
|
|
||||||
dtype=params_dtype))
|
|
||||||
if perform_initialization:
|
|
||||||
self.master_weight = _initialize_affine_weight_cpu(
|
|
||||||
self.weight, self.output_size, self.input_size,
|
|
||||||
self.output_size_per_partition, 0, init_method,
|
|
||||||
stride=stride, return_master_weight=keep_master_weight_for_test)
|
|
||||||
else:
|
|
||||||
self.weight = Parameter(torch.empty(
|
|
||||||
self.output_size_per_partition, self.input_size,
|
|
||||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
|
||||||
if perform_initialization:
|
|
||||||
_initialize_affine_weight_gpu(self.weight, init_method,
|
|
||||||
partition_dim=0, stride=stride)
|
|
||||||
|
|
||||||
if bias:
|
if bias:
|
||||||
if use_cpu_initialization:
|
self.bias = Parameter(torch.empty(
|
||||||
self.bias = Parameter(torch.empty(
|
self.output_size_per_partition,
|
||||||
self.output_size_per_partition, dtype=params_dtype))
|
device=torch.cuda.current_device(),
|
||||||
else:
|
dtype=params_dtype))
|
||||||
self.bias = Parameter(torch.empty(
|
|
||||||
self.output_size_per_partition,
|
|
||||||
device=torch.cuda.current_device(),
|
|
||||||
dtype=params_dtype))
|
|
||||||
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
|
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
|
||||||
# Always initialize bias to zero.
|
# Always initialize bias to zero.
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -292,6 +209,17 @@ class ColumnParallelLinear(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.register_parameter('bias', None)
|
self.register_parameter('bias', None)
|
||||||
|
|
||||||
|
def create_weights(self, dtype: torch.dtype) -> None:
|
||||||
|
self.weight = Parameter(torch.empty(
|
||||||
|
self.output_size_per_partition, self.input_size,
|
||||||
|
device=torch.cuda.current_device(), dtype=dtype))
|
||||||
|
|
||||||
|
def apply_weights(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return F.linear(x, self.weight, bias)
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(self, input_):
|
||||||
"""Forward of ColumnParallelLinear
|
"""Forward of ColumnParallelLinear
|
||||||
@@ -307,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module):
|
|||||||
|
|
||||||
input_parallel = input_
|
input_parallel = input_
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
output_parallel = F.linear(input_parallel, self.weight, bias)
|
output_parallel = self.apply_weights(input_parallel, bias)
|
||||||
if self.gather_output:
|
if self.gather_output:
|
||||||
# All-gather across the partitions.
|
# All-gather across the partitions.
|
||||||
output = gather_from_tensor_model_parallel_region(output_parallel)
|
output = gather_from_tensor_model_parallel_region(output_parallel)
|
||||||
@@ -350,6 +278,7 @@ class RowParallelLinear(torch.nn.Module):
|
|||||||
params_dtype:
|
params_dtype:
|
||||||
use_cpu_initialization:
|
use_cpu_initialization:
|
||||||
perform_initialization:
|
perform_initialization:
|
||||||
|
reduce_results:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input_size, output_size, *,
|
def __init__(self, input_size, output_size, *,
|
||||||
@@ -359,58 +288,52 @@ class RowParallelLinear(torch.nn.Module):
|
|||||||
skip_bias_add=False,
|
skip_bias_add=False,
|
||||||
params_dtype=None,
|
params_dtype=None,
|
||||||
use_cpu_initialization=False,
|
use_cpu_initialization=False,
|
||||||
perform_initialization=True,
|
perform_initialization=False,
|
||||||
|
reduce_results=True,
|
||||||
|
quant_config=None,
|
||||||
):
|
):
|
||||||
super(RowParallelLinear, self).__init__()
|
super(RowParallelLinear, self).__init__()
|
||||||
|
assert not perform_initialization
|
||||||
|
assert not use_cpu_initialization
|
||||||
|
|
||||||
# Keep input parameters
|
# Keep input parameters
|
||||||
self.input_size = input_size
|
self.input_size = input_size
|
||||||
self.output_size = output_size
|
self.output_size = output_size
|
||||||
self.input_is_parallel = input_is_parallel
|
self.input_is_parallel = input_is_parallel
|
||||||
|
self.reduce_results = reduce_results
|
||||||
if params_dtype is None:
|
if params_dtype is None:
|
||||||
params_dtype = torch.get_default_dtype()
|
params_dtype = torch.get_default_dtype()
|
||||||
|
|
||||||
# Divide the weight matrix along the last dimension.
|
# Divide the weight matrix along the last dimension.
|
||||||
world_size = get_tensor_model_parallel_world_size()
|
self.world_size = get_tensor_model_parallel_world_size()
|
||||||
self.input_size_per_partition = divide(input_size, world_size)
|
self.input_size_per_partition = divide(input_size, self.world_size)
|
||||||
self.skip_bias_add = skip_bias_add
|
self.skip_bias_add = skip_bias_add
|
||||||
|
self.quant_config = quant_config
|
||||||
|
|
||||||
|
self.create_weights(params_dtype)
|
||||||
|
|
||||||
|
if not reduce_results and (bias and not skip_bias_add):
|
||||||
|
raise ValueError("When not reduce the results, adding bias to the "
|
||||||
|
"results can lead to incorrect results")
|
||||||
|
|
||||||
# Parameters.
|
|
||||||
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
|
||||||
# we allocate the transpose.
|
|
||||||
# Initialize weight.
|
|
||||||
if use_cpu_initialization:
|
|
||||||
self.weight = Parameter(torch.empty(self.output_size,
|
|
||||||
self.input_size_per_partition,
|
|
||||||
dtype=params_dtype))
|
|
||||||
if perform_initialization:
|
|
||||||
self.master_weight = _initialize_affine_weight_cpu(
|
|
||||||
self.weight, self.output_size, self.input_size,
|
|
||||||
self.input_size_per_partition, 1, init_method,
|
|
||||||
stride=stride, return_master_weight=keep_master_weight_for_test,
|
|
||||||
params_dtype=params_dtype)
|
|
||||||
else:
|
|
||||||
self.weight = Parameter(torch.empty(
|
|
||||||
self.output_size, self.input_size_per_partition,
|
|
||||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
|
||||||
if perform_initialization:
|
|
||||||
_initialize_affine_weight_gpu(self.weight, init_method,
|
|
||||||
partition_dim=1, stride=stride)
|
|
||||||
if bias:
|
if bias:
|
||||||
if use_cpu_initialization:
|
self.bias = Parameter(torch.empty(
|
||||||
self.bias = Parameter(torch.empty(self.output_size,
|
self.output_size, device=torch.cuda.current_device(),
|
||||||
dtype=params_dtype))
|
dtype=params_dtype))
|
||||||
else:
|
|
||||||
self.bias = Parameter(torch.empty(
|
|
||||||
self.output_size, device=torch.cuda.current_device(),
|
|
||||||
dtype=params_dtype))
|
|
||||||
|
|
||||||
# Always initialize bias to zero.
|
# Always initialize bias to zero.
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.bias.zero_()
|
self.bias.zero_()
|
||||||
else:
|
else:
|
||||||
self.register_parameter('bias', None)
|
self.register_parameter('bias', None)
|
||||||
self.weight_t = self.weight.t()
|
|
||||||
|
def create_weights(self, dtype: torch.dtype) -> None:
|
||||||
|
self.weight = Parameter(torch.empty(
|
||||||
|
self.output_size, self.input_size_per_partition,
|
||||||
|
device=torch.cuda.current_device(), dtype=dtype))
|
||||||
|
|
||||||
|
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return F.linear(x, self.weight)
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(self, input_):
|
||||||
"""Forward of RowParallelLinear
|
"""Forward of RowParallelLinear
|
||||||
@@ -427,17 +350,12 @@ class RowParallelLinear(torch.nn.Module):
|
|||||||
input_parallel = input_
|
input_parallel = input_
|
||||||
else:
|
else:
|
||||||
input_parallel = scatter_to_tensor_model_parallel_region(input_)
|
input_parallel = scatter_to_tensor_model_parallel_region(input_)
|
||||||
if get_tensor_model_parallel_world_size() == 1:
|
# Matrix multiply.
|
||||||
# Matrix multiply.
|
output_parallel = self.apply_weights(input_parallel)
|
||||||
output_ = F.linear(input_parallel, self.weight)
|
if self.reduce_results and self.world_size > 1:
|
||||||
|
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
|
||||||
else:
|
else:
|
||||||
# Matrix multiply.
|
output_ = output_parallel
|
||||||
all_reduce_launcher = get_all_reduce_launcher()
|
|
||||||
num_tokens = input_parallel.shape[0]
|
|
||||||
output_buffer = all_reduce_launcher.buffer[:num_tokens]
|
|
||||||
torch.matmul(input_parallel, self.weight_t, out=output_buffer)
|
|
||||||
# All-reduce across all the partitions.
|
|
||||||
output_ = all_reduce_launcher.launch(output_buffer)
|
|
||||||
|
|
||||||
if not self.skip_bias_add:
|
if not self.skip_bias_add:
|
||||||
output = output_ + self.bias if self.bias is not None else output_
|
output = output_ + self.bias if self.bias is not None else output_
|
||||||
|
|||||||
20
vllm/model_executor/quantization_utils/__init__.py
Normal file
20
vllm/model_executor/quantization_utils/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from vllm.model_executor.quantization_utils.awq import AWQConfig
|
||||||
|
from vllm.model_executor.quantization_utils.base import QuantizationConfig
|
||||||
|
|
||||||
|
_QUANTIZATION_REGISTRY = {
|
||||||
|
"awq": AWQConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_quant_class(quantization: str) -> Type[QuantizationConfig]:
|
||||||
|
if quantization not in _QUANTIZATION_REGISTRY:
|
||||||
|
raise ValueError(f"Invalid quantization method: {quantization}")
|
||||||
|
return _QUANTIZATION_REGISTRY[quantization]
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"QuantizationConfig",
|
||||||
|
"get_quant_class",
|
||||||
|
]
|
||||||
72
vllm/model_executor/quantization_utils/awq.py
Normal file
72
vllm/model_executor/quantization_utils/awq.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.quantization_utils.base import QuantizationConfig
|
||||||
|
|
||||||
|
|
||||||
|
class AWQConfig(QuantizationConfig):
|
||||||
|
"""Config class for AWQ.
|
||||||
|
|
||||||
|
Reference: https://arxiv.org/abs/2306.00978
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
weight_bits: int,
|
||||||
|
group_size: int,
|
||||||
|
zero_point: bool,
|
||||||
|
) -> None:
|
||||||
|
self.weight_bits = weight_bits
|
||||||
|
self.group_size = group_size
|
||||||
|
self.zero_point = zero_point
|
||||||
|
|
||||||
|
if self.weight_bits != 4:
|
||||||
|
raise ValueError(
|
||||||
|
"Currently, only 4-bit weight quantization is supported for "
|
||||||
|
f"AWQ, but got {self.weight_bits} bits.")
|
||||||
|
self.pack_factor = 32 // self.weight_bits
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (f"AWQConfig(weight_bits={self.weight_bits}, "
|
||||||
|
f"group_size={self.group_size}, "
|
||||||
|
f"zero_point={self.zero_point})")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_name(cls) -> str:
|
||||||
|
return "awq"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||||
|
return [torch.half]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
# The AWQ kernel only supports Ampere or newer GPUs.
|
||||||
|
return 80
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config_filenames(cls) -> List[str]:
|
||||||
|
return [
|
||||||
|
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
|
||||||
|
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long
|
||||||
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
|
||||||
|
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||||
|
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
||||||
|
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||||
|
return cls(weight_bits, group_size, zero_point)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_packed_tensor_names(cls) -> List[str]:
|
||||||
|
return ["qweight", "qzeros"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_transposed_tensor_names(cls) -> List[str]:
|
||||||
|
return ["qweight", "qzeros", "scales"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_tp_tensor_names(cls) -> List[str]:
|
||||||
|
return ["qweight", "qzeros", "scales"]
|
||||||
75
vllm/model_executor/quantization_utils/base.py
Normal file
75
vllm/model_executor/quantization_utils/base.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class QuantizationConfig:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_name(cls) -> str:
|
||||||
|
"""Name of the quantization method."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||||
|
"""List of supported activation dtypes."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
"""Minimum GPU capability to support the quantization method.
|
||||||
|
|
||||||
|
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
|
||||||
|
This requirement is due to the custom CUDA kernels used by the
|
||||||
|
quantization method.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config_filenames(cls) -> List[str]:
|
||||||
|
"""List of filenames to search for in the model directory."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
|
||||||
|
"""Create a config class from the model's quantization config."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
|
||||||
|
"""Get a value from the model's quantization config."""
|
||||||
|
for key in keys:
|
||||||
|
if key in config:
|
||||||
|
return config[key]
|
||||||
|
raise ValueError(f"Cannot find any of {keys} in the model's "
|
||||||
|
"quantization config.")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_packed_tensor_names(cls) -> List[str]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_packed(cls, tensor_name: str) -> bool:
|
||||||
|
"""Returns True if a tensor is packed.
|
||||||
|
|
||||||
|
A tensor is considered packed if each element in the tensor is a
|
||||||
|
packed representation of multiple elements in the original tensor.
|
||||||
|
For example, an INT32 element in the tensor may represent 8 INT4
|
||||||
|
elements in the original tensor.
|
||||||
|
"""
|
||||||
|
return any(tag in tensor_name for tag in cls.get_packed_tensor_names())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_transposed_tensor_names(cls) -> List[str]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_transposed(cls, tensor_name: str) -> bool:
|
||||||
|
"""Returns True if a tensor is transposed relative to nn.Linear.weight.
|
||||||
|
"""
|
||||||
|
return any(tag in tensor_name
|
||||||
|
for tag in cls.get_transposed_tensor_names())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_tp_tensor_names(cls) -> List[str]:
|
||||||
|
raise NotImplementedError
|
||||||
@@ -3,13 +3,21 @@ import filelock
|
|||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Iterator, List, Optional, Tuple
|
from collections import defaultdict
|
||||||
|
from typing import Any, Iterator, List, Optional, Tuple
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
from safetensors.torch import load_file, save_file, safe_open
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.quantization_utils import get_quant_class
|
||||||
|
from vllm.model_executor.quantization_utils.base import QuantizationConfig
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Disabledtqdm(tqdm):
|
class Disabledtqdm(tqdm):
|
||||||
|
|
||||||
@@ -17,40 +25,186 @@ class Disabledtqdm(tqdm):
|
|||||||
super().__init__(*args, **kwargs, disable=True)
|
super().__init__(*args, **kwargs, disable=True)
|
||||||
|
|
||||||
|
|
||||||
def hf_model_weights_iterator(
|
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
|
||||||
model_name_or_path: str,
|
|
||||||
cache_dir: Optional[str] = None,
|
|
||||||
use_np_cache: bool = False,
|
|
||||||
) -> Iterator[Tuple[str, torch.Tensor]]:
|
|
||||||
# Prepare file lock directory to prevent multiple processes from
|
|
||||||
# downloading the same model weights at the same time.
|
|
||||||
lock_dir = cache_dir if cache_dir is not None else "/tmp"
|
lock_dir = cache_dir if cache_dir is not None else "/tmp"
|
||||||
lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
|
lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
|
||||||
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
|
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
|
||||||
|
return lock
|
||||||
|
|
||||||
# Download model weights from huggingface.
|
|
||||||
|
def _shared_pointers(tensors):
|
||||||
|
ptrs = defaultdict(list)
|
||||||
|
for k, v in tensors.items():
|
||||||
|
ptrs[v.data_ptr()].append(k)
|
||||||
|
failing = []
|
||||||
|
for _, names in ptrs.items():
|
||||||
|
if len(names) > 1:
|
||||||
|
failing.append(names)
|
||||||
|
return failing
|
||||||
|
|
||||||
|
|
||||||
|
def convert_bin_to_safetensor_file(
|
||||||
|
pt_filename: str,
|
||||||
|
sf_filename: str,
|
||||||
|
) -> None:
|
||||||
|
loaded = torch.load(pt_filename, map_location="cpu")
|
||||||
|
if "state_dict" in loaded:
|
||||||
|
loaded = loaded["state_dict"]
|
||||||
|
shared = _shared_pointers(loaded)
|
||||||
|
for shared_weights in shared:
|
||||||
|
for name in shared_weights[1:]:
|
||||||
|
loaded.pop(name)
|
||||||
|
|
||||||
|
# For tensors to be contiguous
|
||||||
|
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
||||||
|
|
||||||
|
dirname = os.path.dirname(sf_filename)
|
||||||
|
os.makedirs(dirname, exist_ok=True)
|
||||||
|
save_file(loaded, sf_filename, metadata={"format": "pt"})
|
||||||
|
|
||||||
|
# check file size
|
||||||
|
sf_size = os.stat(sf_filename).st_size
|
||||||
|
pt_size = os.stat(pt_filename).st_size
|
||||||
|
if (sf_size - pt_size) / pt_size > 0.01:
|
||||||
|
raise RuntimeError(f"""The file size different is more than 1%:
|
||||||
|
- {sf_filename}: {sf_size}
|
||||||
|
- {pt_filename}: {pt_size}
|
||||||
|
""")
|
||||||
|
|
||||||
|
# check if the tensors are the same
|
||||||
|
reloaded = load_file(sf_filename)
|
||||||
|
for k in loaded:
|
||||||
|
pt_tensor = loaded[k]
|
||||||
|
sf_tensor = reloaded[k]
|
||||||
|
if not torch.equal(pt_tensor, sf_tensor):
|
||||||
|
raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(woosuk): Move this to other place.
|
||||||
|
def get_quant_config(
|
||||||
|
quantization: str,
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
) -> QuantizationConfig:
|
||||||
is_local = os.path.isdir(model_name_or_path)
|
is_local = os.path.isdir(model_name_or_path)
|
||||||
if not is_local:
|
if not is_local:
|
||||||
with lock:
|
# Download the config files.
|
||||||
|
with get_lock(model_name_or_path, cache_dir):
|
||||||
hf_folder = snapshot_download(model_name_or_path,
|
hf_folder = snapshot_download(model_name_or_path,
|
||||||
allow_patterns="*.bin",
|
allow_patterns="*.json",
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
tqdm_class=Disabledtqdm)
|
tqdm_class=Disabledtqdm)
|
||||||
else:
|
else:
|
||||||
hf_folder = model_name_or_path
|
hf_folder = model_name_or_path
|
||||||
|
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
|
||||||
|
|
||||||
hf_bin_files = glob.glob(os.path.join(hf_folder, "*.bin"))
|
quant_cls = get_quant_class(quantization)
|
||||||
|
quant_config_files = [
|
||||||
|
f for f in config_files if any(
|
||||||
|
f.endswith(x) for x in quant_cls.get_config_filenames())
|
||||||
|
]
|
||||||
|
if len(quant_config_files) == 0:
|
||||||
|
raise ValueError(f"Cannot find the config file for {quantization}")
|
||||||
|
if len(quant_config_files) > 1:
|
||||||
|
raise ValueError(f"Found multiple config files for {quantization}: "
|
||||||
|
f"{quant_config_files}")
|
||||||
|
|
||||||
|
quant_config_file = quant_config_files[0]
|
||||||
|
with open(quant_config_file, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
return quant_cls.from_config(config)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_hf_model_weights(
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
use_safetensors: bool = False,
|
||||||
|
fall_back_to_pt: bool = True,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
) -> Tuple[str, List[str], bool]:
|
||||||
|
# Download model weights from huggingface.
|
||||||
|
is_local = os.path.isdir(model_name_or_path)
|
||||||
|
if use_safetensors:
|
||||||
|
allow_patterns = ["*.safetensors"]
|
||||||
|
else:
|
||||||
|
# Some quantized models use .pt files for storing the weights.
|
||||||
|
allow_patterns = ["*.bin", "*.pt"]
|
||||||
|
if not is_local:
|
||||||
|
# Use file lock to prevent multiple processes from
|
||||||
|
# downloading the same model weights at the same time.
|
||||||
|
with get_lock(model_name_or_path, cache_dir):
|
||||||
|
hf_folder = snapshot_download(model_name_or_path,
|
||||||
|
allow_patterns=allow_patterns,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
tqdm_class=Disabledtqdm,
|
||||||
|
revision=revision)
|
||||||
|
else:
|
||||||
|
hf_folder = model_name_or_path
|
||||||
|
hf_weights_files: List[str] = []
|
||||||
|
for pattern in allow_patterns:
|
||||||
|
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||||
|
if not use_safetensors:
|
||||||
|
hf_weights_files = [
|
||||||
|
x for x in hf_weights_files if not x.endswith("training_args.bin")
|
||||||
|
]
|
||||||
|
|
||||||
|
if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
|
||||||
|
return prepare_hf_model_weights(model_name_or_path,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
use_safetensors=False,
|
||||||
|
fall_back_to_pt=False,
|
||||||
|
revision=revision)
|
||||||
|
|
||||||
|
if len(hf_weights_files) == 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot find any model weights with `{model_name_or_path}`")
|
||||||
|
|
||||||
|
return hf_folder, hf_weights_files, use_safetensors
|
||||||
|
|
||||||
|
|
||||||
|
def hf_model_weights_iterator(
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
) -> Iterator[Tuple[str, torch.Tensor]]:
|
||||||
|
use_safetensors = False
|
||||||
|
use_np_cache = False
|
||||||
|
fall_back_to_pt = False
|
||||||
|
if load_format == "auto":
|
||||||
|
use_safetensors = True
|
||||||
|
fall_back_to_pt = True
|
||||||
|
elif load_format == "safetensors":
|
||||||
|
use_safetensors = True
|
||||||
|
elif load_format == "pt":
|
||||||
|
pass
|
||||||
|
elif load_format == "npcache":
|
||||||
|
use_np_cache = True
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown load_format: {load_format}")
|
||||||
|
|
||||||
|
hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
|
||||||
|
model_name_or_path,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
|
fall_back_to_pt=fall_back_to_pt,
|
||||||
|
revision=revision)
|
||||||
|
|
||||||
if use_np_cache:
|
if use_np_cache:
|
||||||
|
# Currently np_cache only support *.bin checkpoints
|
||||||
|
assert use_safetensors is False
|
||||||
|
|
||||||
# Convert the model weights from torch tensors to numpy arrays for
|
# Convert the model weights from torch tensors to numpy arrays for
|
||||||
# faster loading.
|
# faster loading.
|
||||||
np_folder = os.path.join(hf_folder, "np")
|
np_folder = os.path.join(hf_folder, "np")
|
||||||
os.makedirs(np_folder, exist_ok=True)
|
os.makedirs(np_folder, exist_ok=True)
|
||||||
weight_names_file = os.path.join(np_folder, "weight_names.json")
|
weight_names_file = os.path.join(np_folder, "weight_names.json")
|
||||||
with lock:
|
# Use file lock to prevent multiple processes from
|
||||||
|
# dumping the same model weights to numpy at the same time.
|
||||||
|
with get_lock(model_name_or_path, cache_dir):
|
||||||
if not os.path.exists(weight_names_file):
|
if not os.path.exists(weight_names_file):
|
||||||
weight_names = []
|
weight_names = []
|
||||||
for bin_file in hf_bin_files:
|
for bin_file in hf_weights_files:
|
||||||
state = torch.load(bin_file, map_location="cpu")
|
state = torch.load(bin_file, map_location="cpu")
|
||||||
for name, param in state.items():
|
for name, param in state.items():
|
||||||
param_path = os.path.join(np_folder, name)
|
param_path = os.path.join(np_folder, name)
|
||||||
@@ -68,16 +222,52 @@ def hf_model_weights_iterator(
|
|||||||
with open(param_path, "rb") as f:
|
with open(param_path, "rb") as f:
|
||||||
param = np.load(f)
|
param = np.load(f)
|
||||||
yield name, torch.from_numpy(param)
|
yield name, torch.from_numpy(param)
|
||||||
|
elif use_safetensors:
|
||||||
|
for st_file in hf_weights_files:
|
||||||
|
with safe_open(st_file, framework="pt") as f:
|
||||||
|
for name in f.keys():
|
||||||
|
param = f.get_slice(name)
|
||||||
|
yield name, param
|
||||||
else:
|
else:
|
||||||
for bin_file in hf_bin_files:
|
for bin_file in hf_weights_files:
|
||||||
state = torch.load(bin_file, map_location="cpu")
|
state = torch.load(bin_file, map_location="cpu")
|
||||||
for name, param in state.items():
|
for name, param in state.items():
|
||||||
yield name, param
|
yield name, param
|
||||||
|
del state
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
|
||||||
|
"""convert PySafeSlice object from safetensors to torch.Tensor
|
||||||
|
|
||||||
|
PySafeSlice object supports indexing, which is done before loading the
|
||||||
|
actual tensor and can reduce the amount of memory being read into the
|
||||||
|
memory. However, it does not support more advanced functionalities
|
||||||
|
like `.view()` or `.t()`. Therefore, if we need to modify the loaded
|
||||||
|
tensor with these more complicated operators, we need to convert to
|
||||||
|
tensor first.
|
||||||
|
"""
|
||||||
|
if not isinstance(x, torch.Tensor):
|
||||||
|
x = x[:]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def load_padded_tensor_parallel_vocab(
|
||||||
|
param: torch.Tensor,
|
||||||
|
loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
|
||||||
|
tensor_model_parallel_rank: int,
|
||||||
|
) -> None:
|
||||||
|
shard_size = param.shape[0]
|
||||||
|
start_idx = tensor_model_parallel_rank * shard_size
|
||||||
|
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||||
|
loaded_weight = loaded_weight[start_idx:end_idx]
|
||||||
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
|
param[:loaded_weight.shape[0]].copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
def load_tensor_parallel_weights(
|
def load_tensor_parallel_weights(
|
||||||
param: torch.Tensor,
|
param: torch.Tensor,
|
||||||
loaded_weight: torch.Tensor,
|
loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
|
||||||
param_name: str,
|
param_name: str,
|
||||||
column_parallel_weight_names: List[str],
|
column_parallel_weight_names: List[str],
|
||||||
row_parallel_weight_names: List[str],
|
row_parallel_weight_names: List[str],
|
||||||
@@ -97,6 +287,8 @@ def load_tensor_parallel_weights(
|
|||||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||||
loaded_weight = loaded_weight[:, start_idx:end_idx]
|
loaded_weight = loaded_weight[:, start_idx:end_idx]
|
||||||
break
|
break
|
||||||
|
|
||||||
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
assert param.shape == loaded_weight.shape, (
|
assert param.shape == loaded_weight.shape, (
|
||||||
f"{param_name} shape mismatch between model and checkpoint: "
|
f"{param_name} shape mismatch between model and checkpoint: "
|
||||||
f"{param.shape} != {loaded_weight.shape}")
|
f"{param.shape} != {loaded_weight.shape}")
|
||||||
|
|||||||
@@ -75,10 +75,12 @@ class RequestOutput:
|
|||||||
# Get the top-n sequences.
|
# Get the top-n sequences.
|
||||||
n = seq_group.sampling_params.n
|
n = seq_group.sampling_params.n
|
||||||
seqs = seq_group.get_seqs()
|
seqs = seq_group.get_seqs()
|
||||||
assert n <= len(seqs)
|
if seq_group.sampling_params.use_beam_search:
|
||||||
sorted_seqs = sorted(seqs,
|
sorting_key = lambda seq: seq.get_beam_search_score(
|
||||||
key=lambda seq: seq.get_cumulative_logprob(),
|
seq_group.sampling_params.length_penalty)
|
||||||
reverse=True)
|
else:
|
||||||
|
sorting_key = lambda seq: seq.get_cumulative_logprob()
|
||||||
|
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
|
||||||
top_n_seqs = sorted_seqs[:n]
|
top_n_seqs = sorted_seqs[:n]
|
||||||
|
|
||||||
# Create the outputs.
|
# Create the outputs.
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user