Compare commits

15 Commits

Author SHA1 Message Date
Woosuk Kwon
31c1f3255e Bump up to v0.2.5 (#2095)
2023-12-13 23:56:15 -08:00
Antoni Baum
21d93c140d Optimize Mixtral with expert parallelism (#2090) 2023-12-13 23:55:07 -08:00
Woosuk Kwon
f1c8520146 [BugFix] Fix input positions for long context with sliding window (#2088) 2023-12-13 12:28:13 -08:00
Woosuk Kwon
096827c284 [Docs] Add notes on ROCm-supported models (#2087) 2023-12-13 09:45:34 -08:00
Woosuk Kwon
6565d9e33e Update installation instruction for vLLM + CUDA 11.8 (#2086) 2023-12-13 09:25:59 -08:00
TJian
f375ec8440 [ROCm] Upgrade xformers version for ROCm & update doc (#2079)
Co-authored-by: miloice <jeffaw99@hotmail.com>
2023-12-13 00:56:05 -08:00
Woosuk Kwon
518369d78c Implement lazy model loader (#2044) 2023-12-12 22:21:45 -08:00
Woosuk Kwon
30bad5c492 Fix peak memory profiling (#2031) 2023-12-12 22:01:53 -08:00
Simon Mo
3fefe271ec Update Dockerfile to build Megablocks (#2042) 2023-12-12 17:34:17 -08:00
Megha Agarwal
6428f1d051 Support MPT with GQA (#1938)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2023-12-12 10:16:05 -08:00
Woosuk Kwon
7e1b21daac Remove einops from requirements (#2049) 2023-12-12 09:34:09 -08:00
Woosuk Kwon
cb3f30c600 Upgrade transformers version to 4.36.0 (#2046) 2023-12-11 18:39:14 -08:00
Woosuk Kwon
f3e024bece [CI/CD] Upgrade PyTorch version to v2.1.1 (#2045) 2023-12-11 17:48:11 -08:00
Woosuk Kwon
31d2ab4aff Remove python 3.10 requirement (#2040) 2023-12-11 12:26:42 -08:00
Simon Mo
eb17212858 Update Dockerfile to support Mixtral (#2027) 2023-12-11 11:59:08 -08:00
27 changed files with 514 additions and 521 deletions

View File

@@ -49,7 +49,7 @@ jobs:
 matrix:
 os: ['ubuntu-20.04']
 python-version: ['3.8', '3.9', '3.10', '3.11']
-pytorch-version: ['2.1.0']
+pytorch-version: ['2.1.1']
 cuda-version: ['11.8', '12.1']
 steps:

View File

@@ -75,7 +75,7 @@ ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"]
 FROM vllm-base AS vllm-openai
 # install additional dependencies for openai api server
 RUN --mount=type=cache,target=/root/.cache/pip \
-pip install accelerate fschat
+pip install accelerate
 COPY --from=build /workspace/vllm/*.so /workspace/vllm/
 COPY vllm vllm

View File

@@ -47,12 +47,12 @@ RUN mkdir libs \
 COPY ./ /app/vllm
 RUN python3 -m pip install --upgrade pip
-RUN pip install xformers==0.0.22.post7 --no-deps
+RUN pip install xformers==0.0.23 --no-deps
 RUN cd /app \
 && cd vllm \
 && pip install -U -r requirements-rocm.txt \
-&& bash patch_xformers-0.0.22.post7.rocm.sh \
+&& bash patch_xformers-0.0.23.rocm.sh \
 && python3 setup.py install \
 && cd ..

View File

@@ -72,10 +72,6 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get
```bash ```bash
pip install vllm pip install vllm
``` ```
**NOTE:** The Mixtral model additionally requires `megablocks` which can be installed with pip or [from source](https://github.com/stanford-futuredata/megablocks) on **Python 3.10**:
```bash
pip install megablocks
```
## Getting Started ## Getting Started

View File

@@ -3,7 +3,7 @@
Installation with ROCm Installation with ROCm
====================== ======================
vLLM 0.2.x onwards supports model inferencing and serving on AMD GPUs with ROCm. vLLM 0.2.4 onwards supports model inferencing and serving on AMD GPUs with ROCm.
At the moment AWQ quantization is not supported in ROCm, but SqueezeLLM quantization has been ported. At the moment AWQ quantization is not supported in ROCm, but SqueezeLLM quantization has been ported.
Data types currently supported in ROCm are FP16 and BF16. Data types currently supported in ROCm are FP16 and BF16.
@@ -29,7 +29,7 @@ Installation options:
.. code-block:: console .. code-block:: console
$ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.3 $ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4
$ docker run -it \ $ docker run -it \
--network=host \ --network=host \
--group-add=video \ --group-add=video \
@@ -70,12 +70,12 @@ You can build and install vLLM from source:
- ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention. - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
- You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
2. Setup `xformers==0.0.22.post7` without dependencies, and apply patches to adapt for ROCm flash attention 2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention
.. code-block:: console .. code-block:: console
$ pip install xformers==0.0.22.post7 --no-deps $ pip install xformers==0.0.23 --no-deps
$ bash patch_xformers-0.0.22.post7.rocm.sh $ bash patch_xformers.rocm.sh
3. Build vLLM. 3. Build vLLM.
@@ -127,12 +127,12 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from
- ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention. - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
- You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
2. Setup `xformers==0.0.22.post7` without dependencies, and apply patches to adapt for ROCm flash attention 2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention
.. code-block:: console .. code-block:: console
$ pip install xformers==0.0.22.post7 --no-deps $ pip install xformers==0.0.23 --no-deps
$ bash patch_xformers-0.0.22.post7.rocm.sh $ bash patch_xformers.rocm.sh
3. Build vLLM. 3. Build vLLM.

View File

@@ -20,7 +20,7 @@ You can install vLLM using pip:
.. code-block:: console .. code-block:: console
$ # (Optional) Create a new conda environment. $ # (Optional) Create a new conda environment.
$ conda create -n myenv python=3.8 -y $ conda create -n myenv python=3.9 -y
$ conda activate myenv $ conda activate myenv
$ # Install vLLM with CUDA 12.1. $ # Install vLLM with CUDA 12.1.
@@ -34,8 +34,9 @@ You can install vLLM using pip:
.. code-block:: console .. code-block:: console
$ # Install vLLM with CUDA 11.8. $ # Install vLLM with CUDA 11.8.
$ # Replace `cp310` with your Python version (e.g., `cp38`, `cp39`, `cp311`). $ export VLLM_VERSION=0.2.4
$ pip install https://github.com/vllm-project/vllm/releases/download/v0.2.2/vllm-0.2.2+cu118-cp310-cp310-manylinux1_x86_64.whl $ export PYTHON_VERSION=39
$ pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl
$ # Re-install PyTorch with CUDA 11.8. $ # Re-install PyTorch with CUDA 11.8.
$ pip uninstall torch -y $ pip uninstall torch -y

View File

@@ -73,6 +73,9 @@ If your model uses one of the above model architectures, you can seamlessly run
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model. Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ project. Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ project.
.. note::
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
.. tip:: .. tip::
The easiest way to check if your model is supported is to run the program below: The easiest way to check if your model is supported is to run the program below:
@@ -84,12 +87,17 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr
output = llm.generate("Hello, my name is") output = llm.generate("Hello, my name is")
print(output) print(output)
To use model from www.modelscope.cn If vLLM successfully generates text, it indicates that your model is supported.
.. tip::
To use models from `ModelScope <www.modelscope.cn>`_ instead of HuggingFace Hub, set an environment variable:
.. code-block:: shell .. code-block:: shell
$ export VLLM_USE_MODELSCOPE=True $ export VLLM_USE_MODELSCOPE=True
And use with :code:`trust_remote_code=True`.
.. code-block:: python .. code-block:: python
from vllm import LLM from vllm import LLM
@@ -97,5 +105,3 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr
llm = LLM(model=..., revision=..., trust_remote_code=True) # Name or path of your model llm = LLM(model=..., revision=..., trust_remote_code=True) # Name or path of your model
output = llm.generate("Hello, my name is") output = llm.generate("Hello, my name is")
print(output) print(output)
If vLLM successfully generates text, it indicates that your model is supported.

View File

@@ -1,21 +1,32 @@
#!/bin/bash #!/bin/bash
set -e
XFORMERS_VERSION="0.0.23"
export XFORMERS_INSTALLED_VERSION=$(python -c 'import xformers; print(xformers.__version__)')
if [ "$XFORMERS_INSTALLED_VERSION" != "$XFORMERS_VERSION" ]; then
echo "ERROR: xformers version must be ${XFORMERS_VERSION}. ${XFORMERS_INSTALLED_VERSION} is installed"
exit 1
fi
export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)') export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)')
export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)') export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)')
echo $XFORMERS_FMHA_FLASH_PATH echo "XFORMERS_FMHA_FLASH_PATH = ${XFORMERS_FMHA_FLASH_PATH}"
echo $XFORMERS_FMHA_COMMON_PATH echo "XFORMERS_FMHA_COMMON_PATH = ${XFORMERS_FMHA_COMMON_PATH}"
if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"; then if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}" echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}"
patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch" patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"
echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}" echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}"
else else
echo "${XFORMERS_FMHA_FLASH_PATH} was patched before" echo "${XFORMERS_FMHA_FLASH_PATH} was patched before"
fi fi
if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"; then if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}" echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}"
patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch" patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"
echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}" echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}"
else else
echo "${XFORMERS_FMHA_COMMON_PATH} was patched before" echo "${XFORMERS_FMHA_COMMON_PATH} was patched before"

View File

@@ -4,7 +4,7 @@ requires = [
 "ninja",
 "packaging",
 "setuptools >= 49.4.0",
-"torch >= 2.1.0",
+"torch >= 2.1.1",
 "wheel",
 ]
 build-backend = "setuptools.build_meta"

View File

@@ -8,9 +8,7 @@ pyarrow # Required for Ray data.
 sentencepiece # Required for LLaMA tokenizer.
 numpy
 tokenizers>=0.15.0
-huggingface_hub<0.18,>=0.16.4
-einops # Required for phi-1_5
-transformers >= 4.34.0 # Required for Mistral.
+transformers >= 4.36.0 # Required for Mixtral.
 fastapi
 uvicorn[standard]
 pydantic == 1.10.13 # Required for OpenAI server.

View File

@@ -5,10 +5,9 @@ pandas # Required for Ray data.
 pyarrow # Required for Ray data.
 sentencepiece # Required for LLaMA tokenizer.
 numpy
-einops # Required for phi-1_5
-torch >= 2.1.0
-transformers >= 4.34.0 # Required for Mistral.
-xformers >= 0.0.22.post7 # Required for CUDA 12.1.
+torch >= 2.1.1
+transformers >= 4.36.0 # Required for Mixtral.
+xformers >= 0.0.23 # Required for CUDA 12.1.
 fastapi
 uvicorn[standard]
 pydantic == 1.10.13 # Required for OpenAI server.

View File

@@ -1,6 +1,6 @@
--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py 2023-11-29 03:17:03.930103539 +0000 --- flash_ori.py 2023-12-13 05:43:31.530752623 +0000
+++ flash.py 2023-11-28 16:14:25.206128903 +0000 +++ flash_patch.py 2023-12-13 06:00:45.962403104 +0000
@@ -31,39 +31,39 @@ @@ -36,44 +36,44 @@
FLASH_VERSION = "0.0.0" FLASH_VERSION = "0.0.0"
try: try:
@@ -15,9 +15,12 @@
- from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention - from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
- -
- FLASH_VERSION = flash_attn.__version__ - FLASH_VERSION = flash_attn.__version__
- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2]) - flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
- if flash_ver_parsed < (2, 3): - if (
- raise ImportError("Requires 2.3 for sliding window support") - flash_ver_parsed != (2, 3, 6)
- and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
- ):
- raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
+ #try: + #try:
+ # from ... import _C_flashattention # type: ignore[attr-defined] + # from ... import _C_flashattention # type: ignore[attr-defined]
+ # from ..._cpp_lib import _build_metadata + # from ..._cpp_lib import _build_metadata
@@ -29,35 +32,41 @@
+ from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention + from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
+ +
+ FLASH_VERSION = flash_attn.__version__ + FLASH_VERSION = flash_attn.__version__
+ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2]) + # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
+ # if flash_ver_parsed < (2, 3): + # if (
+ # raise ImportError("Requires 2.3 for sliding window support") + # flash_ver_parsed != (2, 3, 6)
+ # and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
+ # ):
+ # raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
# create library so that flash-attn goes through the PyTorch Dispatcher # create library so that flash-attn goes through the PyTorch Dispatcher
- _flash_lib = torch.library.Library("xformers_flash", "DEF") - _flash_lib = torch.library.Library("xformers_flash", "DEF")
+ #_flash_lib = torch.library.Library("xformers_flash", "DEF") -
- _flash_lib.define( - _flash_lib.define(
- "flash_fwd(Tensor query, Tensor key, Tensor value, " - "flash_fwd(Tensor query, Tensor key, Tensor value, "
- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, " - "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
- "int max_seqlen_q, int max_seqlen_k, " - "int max_seqlen_q, int max_seqlen_k, "
- "float p, float softmax_scale, " - "float p, float softmax_scale, "
- "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)" - "bool is_causal, int window_left, "
- "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
- ) - )
- + #_flash_lib = torch.library.Library("xformers_flash", "DEF")
- _flash_lib.define( - _flash_lib.define(
- "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, " - "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
- "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, " - "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
- "Tensor cu_seqlens_q, Tensor cu_seqlens_k, " - "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
- "int max_seqlen_q, int max_seqlen_k, " - "int max_seqlen_q, int max_seqlen_k, "
- "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)" - "float p, float softmax_scale, bool is_causal, "
- "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
- ) - )
+ #_flash_lib.define( + #_flash_lib.define(
+ # "flash_fwd(Tensor query, Tensor key, Tensor value, " + # "flash_fwd(Tensor query, Tensor key, Tensor value, "
+ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, " + # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
+ # "int max_seqlen_q, int max_seqlen_k, " + # "int max_seqlen_q, int max_seqlen_k, "
+ # "float p, float softmax_scale, " + # "float p, float softmax_scale, "
+ # "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)" + # "bool is_causal, int window_left, "
+ # "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+ #) + #)
+ +
+ #_flash_lib.define( + #_flash_lib.define(
@@ -65,52 +74,61 @@
+ # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, " + # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
+ # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, " + # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
+ # "int max_seqlen_q, int max_seqlen_k, " + # "int max_seqlen_q, int max_seqlen_k, "
+ # "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)" + # "float p, float softmax_scale, bool is_causal, "
+ # "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+ #) + #)
def _flash_fwd( def _flash_fwd(
query, query,
@@ -98,8 +98,8 @@ @@ -111,8 +111,8 @@
p, p,
softmax_scale, softmax_scale,
is_causal, is_causal,
- window_size - 1, # window_size_left - window_left, # window_size_left
- -1, # window_size_right - window_right, # window_size_right
+ # window_size - 1, # window_size_left + # window_left, # window_size_left
+ # -1, # window_size_right + # window_right, # window_size_right
return_softmax, return_softmax,
None, # rng None, # rng
) )
@@ -127,8 +127,8 @@ @@ -134,15 +134,15 @@
out,
cu_seq_lens_q,
cu_seq_lens_k,
- seqused_k,
+ # seqused_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale, softmax_scale,
False, False,
is_causal, is_causal,
- window_size - 1, # window_size_left - window_left,
- -1, # window_size_right - window_right,
+ # window_size - 1, # window_size_left + # window_left,
+ # -1, # window_size_right + # window_right,
return_softmax, return_softmax,
None, None,
) )
@@ -169,8 +169,8 @@ @@ -184,8 +184,8 @@
p, p,
softmax_scale, softmax_scale,
is_causal, is_causal,
- window_size - 1, # window_size_left - window_left,
- -1, # window_size_right - window_right,
+ # window_size - 1, # window_size_left + # window_left,
+ # -1, # window_size_right + # window_right,
None, None,
rng_state, rng_state,
) )
@@ -193,15 +193,15 @@ @@ -208,15 +208,15 @@
softmax_scale, softmax_scale,
False, # zero_tensors False, # zero_tensors
is_causal, is_causal,
- window_size - 1, # window_size_left - window_left,
- -1, # window_size_right - window_right,
+ # window_size - 1, # window_size_left + # window_left,
+ # -1, # window_size_right + # window_right,
None, None,
rng_state, rng_state,
) )
@@ -123,7 +141,7 @@
except ImportError: except ImportError:
pass pass
@@ -348,7 +348,7 @@ @@ -400,7 +400,7 @@
implementation. implementation.
""" """

View File

@@ -1,3 +1,4 @@
import os
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import pytest import pytest
@@ -7,21 +8,32 @@ from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
_TEST_PROMPTS = [ _TEST_PROMPTS = ["prompts/example.txt"]
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.", _LONG_PROMPTS = ["prompts/summary.txt"]
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
"Describe the basic components of a neural network and how it can be trained.", def _read_prompts(filename: str) -> str:
"Write a short story about a robot that dreams for the first time.", prompts = []
"Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.", with open(filename, "r") as f:
"Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.", prompt = f.readline()
"Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'", prompts.append(prompt)
] return prompts
@pytest.fixture @pytest.fixture
def example_prompts() -> List[str]: def example_prompts() -> List[str]:
return _TEST_PROMPTS prompts = []
for filename in _TEST_PROMPTS:
prompts += _read_prompts(os.path.join("tests", filename))
return prompts
@pytest.fixture
def example_long_prompts() -> List[str]:
prompts = []
for filename in _LONG_PROMPTS:
prompts += _read_prompts(os.path.join("tests", filename))
return prompts
_STR_DTYPE_TO_TORCH_DTYPE = { _STR_DTYPE_TO_TORCH_DTYPE = {

View File

@@ -0,0 +1,37 @@
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
Run `pytest tests/models/test_mistral.py --forked`.
"""
import pytest
MODELS = [
"mistralai/Mistral-7B-Instruct-v0.1",
]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(
hf_runner,
vllm_runner,
example_long_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_long_prompts, max_tokens)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype)
vllm_outputs = vllm_model.generate_greedy(example_long_prompts, max_tokens)
del vllm_model
for i in range(len(example_long_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")

View File

@@ -0,0 +1,8 @@
vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.
Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.
Compare and contrast artificial intelligence with human intelligence in terms of processing information.
Describe the basic components of a neural network and how it can be trained.
Write a short story about a robot that dreams for the first time.
Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.
Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.
Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'

File diff suppressed because one or more lines are too long

View File

@@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import SamplingParams
-__version__ = "0.2.4"
+__version__ = "0.2.5"
 __all__ = [
 "LLM",

View File

@@ -120,14 +120,16 @@ class ModelConfig:
 if load_format == "auto":
 load_format = "pt"
-# FIXME(woosuk): This is a temporary hack. Support safetensor weights.
+# TODO: Remove this check once HF updates the pt weights of Mixtral.
 architectures = getattr(self.hf_config, "architectures", [])
-if "MixtralForCausalLM" in architectures and load_format != "pt":
-logger.info(
-"Currently, only 'pt' format is supported for Mixtral. "
-"Changing the format to 'pt'. This may re-download the "
-"weights if you have downloaded the safetensor weights.")
-load_format = "pt"
+if "MixtralForCausalLM" in architectures:
+if load_format == "pt":
+raise ValueError(
+"Currently, the 'pt' format is not supported for Mixtral. "
+"Please use the 'safetensors' format instead. ")
+elif load_format == "auto":
+# Do not fall back to pt weights.
+load_format = "safetensors"
 self.load_format = load_format
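The hunk above inverts the Mixtral rule: the `pt` format is now rejected and `auto` resolves to `safetensors`. A standalone sketch of that decision, assuming a hypothetical helper name `resolve_load_format` (in vLLM the logic lives inline in `ModelConfig`, and the ROCm-specific handling that precedes it is omitted here):

```python
from typing import List


def resolve_load_format(load_format: str, architectures: List[str]) -> str:
    """Return the effective load format, mirroring the new Mixtral check.

    Hypothetical helper for illustration; not a vLLM API.
    """
    if "MixtralForCausalLM" in architectures:
        if load_format == "pt":
            # The new code refuses 'pt' instead of silently switching to it.
            raise ValueError(
                "Currently, the 'pt' format is not supported for Mixtral. "
                "Please use the 'safetensors' format instead.")
        if load_format == "auto":
            # Do not fall back to pt weights.
            return "safetensors"
    # All other architectures and formats pass through unchanged.
    return load_format


if __name__ == "__main__":
    print(resolve_load_format("auto", ["MixtralForCausalLM"]))
    print(resolve_load_format("auto", ["LlamaForCausalLM"]))
```

Raising on `pt` rather than logging and overriding makes the unsupported configuration visible to the caller instead of triggering an unexpected re-download.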

View File

@@ -138,7 +138,8 @@ class PagedAttention(nn.Module):
input_metadata.attn_bias = attn_bias input_metadata.attn_bias = attn_bias
else: else:
input_metadata.attn_bias = _make_alibi_bias( input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, batch_size, seq_len, query.dtype) self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)
# TODO(woosuk): Too many view operations. Let's try to reduce them # TODO(woosuk): Too many view operations. Let's try to reduce them
# in the future for code readability. # in the future for code readability.
@@ -180,31 +181,34 @@ class PagedAttention(nn.Module):
def _make_alibi_bias( def _make_alibi_bias(
alibi_slopes: torch.Tensor, alibi_slopes: torch.Tensor,
num_kv_heads: int,
batch_size: int, batch_size: int,
seq_len: int, seq_len: int,
dtype: torch.dtype, dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias: ) -> LowerTriangularMaskWithTensorBias:
bias = torch.arange(seq_len, dtype=dtype) bias = torch.arange(seq_len, dtype=dtype, device="cuda")
# NOTE(zhuohan): HF uses # NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)` # `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but # here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi # the bias below more accurately follows the original ALiBi
# paper. # paper.
bias = bias[None, :] - bias[:, None] bias = bias[None, :] - bias[:, None]
bias = bias.to(alibi_slopes.device)
# When using custom attention bias, xformers requires the bias to # When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8. # be sliced from a tensor whose length is a multiple of 8.
padded_len = (seq_len + 7) // 8 * 8 padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty( bias = torch.empty(
batch_size, batch_size,
alibi_slopes.shape[0], num_heads,
seq_len, seq_len,
padded_len, padded_len,
device=alibi_slopes.device, device=alibi_slopes.device,
dtype=dtype, dtype=dtype,
)[:, :, :, :seq_len].copy_(bias) )[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None]) bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
attn_bias = LowerTriangularMaskWithTensorBias(bias) attn_bias = LowerTriangularMaskWithTensorBias(bias)
return attn_bias return attn_bias
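The `_make_alibi_bias` change pads the last dimension to a multiple of 8 and, when the query heads outnumber the KV heads (GQA), splits the head dimension into `(num_kv_heads, num_heads // num_kv_heads)`. The shape arithmetic can be sketched without PyTorch; the helper name `alibi_bias_layout` and the divisibility check are assumptions added for illustration, not part of the diff:

```python
from typing import Tuple


def alibi_bias_layout(batch_size: int, num_heads: int, num_kv_heads: int,
                      seq_len: int) -> Tuple[int, Tuple[int, ...]]:
    """Return (padded_len, logical bias shape) for the ALiBi bias tensor.

    Hypothetical helper: it only reproduces the shape bookkeeping of the
    hunk above, not the bias values themselves.
    """
    # Assumption for this sketch: heads must divide evenly into KV groups.
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    # The backing buffer's last dimension is rounded up to a multiple of 8.
    padded_len = (seq_len + 7) // 8 * 8
    # The buffer is sliced back to seq_len before the bias is copied in.
    shape: Tuple[int, ...] = (batch_size, num_heads, seq_len, seq_len)
    if num_heads != num_kv_heads:
        # Equivalent of bias.unflatten(1, (num_kv_heads, group_size)).
        group_size = num_heads // num_kv_heads
        shape = (batch_size, num_kv_heads, group_size, seq_len, seq_len)
    return padded_len, shape


if __name__ == "__main__":
    print(alibi_bias_layout(2, 32, 8, 13))
```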

View File

@@ -7,54 +7,9 @@ import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.model_executor.models import * from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import (get_quant_config, from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights) initialize_dummy_weights)
from vllm.utils import is_hip
from vllm.logger import init_logger
logger = init_logger(__name__)
# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {
"AquilaModel": AquilaForCausalLM,
"AquilaForCausalLM": AquilaForCausalLM, # AquilaChat2
"BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b
"BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b
"BloomForCausalLM": BloomForCausalLM,
"ChatGLMModel": ChatGLMForCausalLM,
"ChatGLMForConditionalGeneration": ChatGLMForCausalLM,
"FalconForCausalLM": FalconForCausalLM,
"GPT2LMHeadModel": GPT2LMHeadModel,
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
"GPTJForCausalLM": GPTJForCausalLM,
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
"InternLMForCausalLM": InternLMForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
"MistralForCausalLM": MistralForCausalLM,
"MixtralForCausalLM": MixtralForCausalLM,
# transformers's mpt class has lower case
"MptForCausalLM": MPTForCausalLM,
"MPTForCausalLM": MPTForCausalLM,
"OPTForCausalLM": OPTForCausalLM,
"PhiForCausalLM": PhiForCausalLM,
"QWenLMHeadModel": QWenLMHeadModel,
"RWForCausalLM": FalconForCausalLM,
"YiForCausalLM": YiForCausalLM,
}
# Models to be disabled in ROCm
_ROCM_UNSUPPORTED_MODELS = []
if is_hip():
for rocm_model in _ROCM_UNSUPPORTED_MODELS:
del _MODEL_REGISTRY[rocm_model]
# Models partially supported in ROCm
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
"MistralForCausalLM":
"Sliding window attention is not supported in ROCm's flash attention",
}
@contextlib.contextmanager @contextlib.contextmanager
@@ -69,19 +24,12 @@ def _set_default_torch_dtype(dtype: torch.dtype):
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", []) architectures = getattr(config, "architectures", [])
for arch in architectures: for arch in architectures:
if arch in _MODEL_REGISTRY: model_cls = ModelRegistry.load_model_cls(arch)
if is_hip() and arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: if model_cls is not None:
logger.warning( return model_cls
f"{arch} is not fully supported in ROCm. Reason: "
f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[arch]}")
return _MODEL_REGISTRY[arch]
elif arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(
f"Model architecture {arch} is not supported by ROCm for now. \n"
f"Supported architectures {list(_MODEL_REGISTRY.keys())}")
raise ValueError( raise ValueError(
f"Model architectures {architectures} are not supported for now. " f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}") f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_model(model_config: ModelConfig) -> nn.Module: def get_model(model_config: ModelConfig) -> nn.Module:


@@ -1,41 +1,82 @@
from vllm.model_executor.models.aquila import AquilaForCausalLM import importlib
from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM, from typing import List, Optional, Type
BaichuanForCausalLM)
from vllm.model_executor.models.bloom import BloomForCausalLM import torch.nn as nn
from vllm.model_executor.models.falcon import FalconForCausalLM
from vllm.model_executor.models.gpt2 import GPT2LMHeadModel from vllm.logger import init_logger
from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM from vllm.utils import is_hip
from vllm.model_executor.models.gpt_j import GPTJForCausalLM
from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM logger = init_logger(__name__)
from vllm.model_executor.models.internlm import InternLMForCausalLM
from vllm.model_executor.models.llama import LlamaForCausalLM # Architecture -> (module, class).
from vllm.model_executor.models.mistral import MistralForCausalLM _MODELS = {
from vllm.model_executor.models.mixtral import MixtralForCausalLM "AquilaModel": ("aquila", "AquilaForCausalLM"),
from vllm.model_executor.models.mpt import MPTForCausalLM "AquilaForCausalLM": ("aquila", "AquilaForCausalLM"), # AquilaChat2
from vllm.model_executor.models.opt import OPTForCausalLM "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
from vllm.model_executor.models.phi_1_5 import PhiForCausalLM "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
from vllm.model_executor.models.qwen import QWenLMHeadModel "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
from vllm.model_executor.models.chatglm import ChatGLMForCausalLM "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
from vllm.model_executor.models.yi import YiForCausalLM "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
"InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("mistral", "MistralForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
# The transformers MPT class name uses lowercase letters (Mpt)
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"YiForCausalLM": ("yi", "YiForCausalLM"),
}
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS = []
# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
"MistralForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention",
"MixtralForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention",
}
class ModelRegistry:
@staticmethod
def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch not in _MODELS:
return None
if is_hip():
if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(
f"Model architecture {model_arch} is not supported by "
"ROCm for now.")
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
logger.warning(
f"Model architecture {model_arch} is partially supported "
"by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
module_name, model_cls_name = _MODELS[model_arch]
module = importlib.import_module(
f"vllm.model_executor.models.{module_name}")
return getattr(module, model_cls_name, None)
@staticmethod
def get_supported_archs() -> List[str]:
return list(_MODELS.keys())
__all__ = [ __all__ = [
"AquilaForCausalLM", "ModelRegistry",
"BaiChuanForCausalLM",
"BaichuanForCausalLM",
"BloomForCausalLM",
"ChatGLMForCausalLM",
"FalconForCausalLM",
"GPT2LMHeadModel",
"GPTBigCodeForCausalLM",
"GPTJForCausalLM",
"GPTNeoXForCausalLM",
"InternLMForCausalLM",
"LlamaForCausalLM",
"MPTForCausalLM",
"OPTForCausalLM",
"PhiForCausalLM",
"QWenLMHeadModel",
"MistralForCausalLM",
"MixtralForCausalLM",
"YiForCausalLM",
] ]
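
The registry in this file replaces the eager per-model imports with a table of architecture name to (module, class) that is resolved only when a model is requested. A minimal sketch of that pattern follows. It uses standard-library modules as stand-ins for the model modules, and the names `_CLASSES`, `load_cls`, and `supported` are hypothetical, chosen for illustration only.

```python
import importlib
from typing import List, Optional

# Name -> (module, attribute). These stdlib entries are illustrative
# stand-ins for the architecture table; they are not vLLM models.
_CLASSES = {
    "Decoder": ("json", "JSONDecoder"),
    "Fraction": ("fractions", "Fraction"),
}


def load_cls(name: str) -> Optional[type]:
    """Import the owning module only when the class is requested."""
    if name not in _CLASSES:
        return None
    module_name, cls_name = _CLASSES[name]
    module = importlib.import_module(module_name)
    # Return None instead of raising if the attribute is missing.
    return getattr(module, cls_name, None)


def supported() -> List[str]:
    """List the registered names without importing anything."""
    return list(_CLASSES.keys())
```

Because the table holds only strings, listing the supported names never triggers an import, and a module with heavy or optional dependencies is loaded only for the model that needs it.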


@@ -29,25 +29,13 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from transformers import MistralConfig from transformers import MixtralConfig
try:
import megablocks.ops as ops
except ImportError:
print(
"MegaBlocks not found. Please install it by `pip install megablocks`. "
"Note that MegaBlocks depends on mosaicml-turbo, which only supports "
"Python 3.10 for now.")
try:
import stk
except ImportError:
print(
"STK not found: please see https://github.com/stanford-futuredata/stk")
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
ReplicatedLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -67,8 +55,134 @@ from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
def promote_scalar(x: torch.Tensor) -> torch.Tensor: class MixtralMLP(nn.Module):
return x.view(1) if len(x.size()) == 0 else x
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.num_experts = num_experts
self.ffn_dim = intermediate_size
self.hidden_dim = hidden_size
self.w1 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
linear_method=linear_method)
self.w2 = ReplicatedLinear(self.ffn_dim,
self.hidden_dim,
bias=False,
linear_method=linear_method)
self.w3 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
linear_method=linear_method)
# TODO: Use vllm's SiluAndMul
self.act_fn = nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
w1_out, _ = self.w1(hidden_states)
w1_out = self.act_fn(w1_out)
w3_out, _ = self.w3(hidden_states)
current_hidden_states = w1_out * w3_out
current_hidden_states, _ = self.w2(current_hidden_states)
return current_hidden_states
class DummyModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.w1 = nn.Linear(0, 0, bias=False)
self.w2 = nn.Linear(0, 0, bias=False)
self.w3 = nn.Linear(0, 0, bias=False)
set_weight_attrs(self.w1.weight,
{"weight_loader": self.dummy_weight_loader})
set_weight_attrs(self.w2.weight,
{"weight_loader": self.dummy_weight_loader})
set_weight_attrs(self.w3.weight,
{"weight_loader": self.dummy_weight_loader})
def forward(self, *args, **kwargs) -> None:
raise NotImplementedError()
def dummy_weight_loader(self, *args, **kwargs) -> None: # pylint: disable=unused-argument
# Noop
return
class MixtralMoE(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.num_total_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.num_total_experts}.")
# Split experts equally between ranks
self.expert_indicies = np.array_split(range(
self.num_total_experts), self.tp_size)[self.rank].tolist()
if not self.expert_indicies:
raise ValueError(
f"Rank {self.rank} has no experts assigned to it.")
self.experts = nn.ModuleList([
MixtralMLP(self.num_total_experts,
config.hidden_size,
config.intermediate_size,
linear_method=linear_method)
if idx in self.expert_indicies else DummyModule()
for idx in range(self.num_total_experts)
])
self.gate = ReplicatedLinear(config.hidden_size,
self.num_total_experts,
bias=False,
linear_method=linear_method)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits, _ = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights,
self.top_k,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
final_hidden_states = None
for expert_idx in self.expert_indicies:
expert_layer = self.experts[expert_idx]
expert_mask = (selected_experts == expert_idx)
expert_weights = (routing_weights * expert_mask).sum(dim=-1,
keepdim=True)
current_hidden_states = expert_layer(hidden_states).mul_(
expert_weights)
if final_hidden_states is None:
final_hidden_states = current_hidden_states
else:
final_hidden_states.add_(current_hidden_states)
return tensor_model_parallel_all_reduce(final_hidden_states).view(
batch_size, sequence_length, hidden_dim)
class MixtralAttention(nn.Module): class MixtralAttention(nn.Module):
@@ -79,6 +193,7 @@ class MixtralAttention(nn.Module):
num_kv_heads: int, num_kv_heads: int,
max_position: int = 4096 * 32, max_position: int = 4096 * 32,
rope_theta: float = 10000, rope_theta: float = 10000,
linear_method: Optional[LinearMethodBase] = None,
sliding_window: Optional[int] = None) -> None: sliding_window: Optional[int] = None) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@@ -103,24 +218,26 @@ class MixtralAttention(nn.Module):
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.sliding_window = sliding_window self.sliding_window = sliding_window
self.wqkv = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size, hidden_size,
self.head_dim, self.head_dim,
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
linear_method=linear_method,
) )
self.wo = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim, rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
base=int(self.rope_theta), base=int(self.rope_theta),
is_neox_style=False, # weights not in HF format is_neox_style=True,
) )
self.attn = PagedAttention( self.attn = PagedAttention(
self.num_heads, self.num_heads,
@@ -138,334 +255,93 @@ class MixtralAttention(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event], cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.wqkv(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event) cache_event)
output, _ = self.wo(attn_output) output, _ = self.o_proj(attn_output)
return output return output
class BlockSparseMoE(nn.Module):
"""
Built on the paper and library Megablocks as described in
https://arxiv.org/abs/2211.15841. This implementation is
strictly equivalent to standard MoE with full capacity (no
dropped tokens). It's faster since it formulates MoE operations
in terms of block-sparse operations to accommodate imbalanced
assignments of tokens to experts, whereas standard MoE either
(1) drops tokens at the cost of reduced performance or (2) sets
capacity factor to number of experts and thus wastes computation
and memory on padding.
"""
def __init__(self, hidden_dim: int, ffn_dim: int, num_experts: int,
top_k: int):
super().__init__()
self.hidden_dim = hidden_dim
self.ffn_dim = ffn_dim
self.num_experts = num_experts
self.top_k = top_k
# gating
self.gate = nn.Linear(self.hidden_dim,
self.num_experts,
bias=False,
device=torch.cuda.current_device())
tp_size = get_tensor_model_parallel_world_size()
assert self.ffn_dim % tp_size == 0
self.ffn_dim_per_partition = self.ffn_dim // tp_size
# merged expert weights, all of size (ffn_dim * n_experts, model_dim)
self.w1 = nn.Parameter(
torch.empty(self.ffn_dim_per_partition * self.num_experts,
self.hidden_dim,
device=torch.cuda.current_device()))
set_weight_attrs(self.w1, {"weight_loader": self.moe_weight_loader})
self.w2 = nn.Parameter(
torch.empty(self.ffn_dim_per_partition * self.num_experts,
self.hidden_dim,
device=torch.cuda.current_device()))
set_weight_attrs(self.w2, {"weight_loader": self.moe_weight_loader})
self.w3 = nn.Parameter(
torch.empty(self.ffn_dim_per_partition * self.num_experts,
self.hidden_dim,
device=torch.cuda.current_device()))
set_weight_attrs(self.w3, {"weight_loader": self.moe_weight_loader})
# Calculate the number of bits needed to represent the expert indices
# so that we can pass it to radix sort.
self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
self.blocking = 128
self.quantize_scatter_num_bits = -1
# Calculate the number of bits needed to represent the column indices
# in the intermediate sparse matrix.
max_column_index = (self.ffn_dim * self.num_experts) // self.blocking
self.transpose_sort_end_bit = max(
int(np.ceil(np.log2(max_column_index))), 1)
def moe_weight_loader(self, param: nn.Parameter,
loaded_weight: torch.Tensor) -> None:
"""
Load the weights for the MoE linear layer.
"""
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.ffn_dim_per_partition
loaded_weight = loaded_weight.view(self.num_experts, self.ffn_dim, -1)
loaded_weight = loaded_weight[:, shard_size * tp_rank:shard_size *
(tp_rank + 1)]
loaded_weight = loaded_weight.reshape_as(param)
param.data.copy_(loaded_weight)
def sparse_transpose(
self, size: int, row_indices,
column_indices) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
block_columns = size[1] // self.blocking
# Sort row indices by column indices to get the transposed matrix's
# column indices.
#
# NOTE: Our sort operation uses the same width indices as the input
# values. To avoid overflow when we have large activation matrices
# we cast to 32-bit before sorting.
_, gather_indices = ops.sort(column_indices.int(),
self.transpose_sort_end_bit)
# There are a constant number of blocks in every row of the sparse
# matrix. A block's offset is:
#
# row_index * blocks_per_row + column_index % blocks_per_row
#
# Once we have the block offsets ordered for transposition we can
# divide by blocks_per_row to get the transposed column indices.
column_indices_t = row_indices.gather(0, gather_indices.long())
block_offsets_t = gather_indices.int()
zero = torch.zeros((1, ), dtype=torch.int32, device=row_indices.device)
nnz_per_column = ops.histogram(column_indices, block_columns)
nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
offsets_t = torch.cat([zero, nnz_per_column])
return column_indices_t, offsets_t, block_offsets_t
def topology(self, x: torch.Tensor,
padded_bins: torch.Tensor) -> "stk.Matrix":
padded_tokens, _ = x.size()
assert padded_tokens % self.blocking == 0
assert self.ffn_dim_per_partition % self.blocking == 0
# Offsets for the sparse matrix. All rows have the
# same number of nonzero blocks dictated by the
# dimensionality of a single expert.
block_rows = padded_tokens // self.blocking
blocks_per_row = self.ffn_dim_per_partition // self.blocking
offsets = torch.arange(
0,
block_rows * blocks_per_row + 1,
blocks_per_row,
dtype=torch.int32,
device=x.device,
)
# Indices for the sparse matrix. The indices for
# the intermediate matrix are dynamic depending
# on the mapping of tokens to experts.
column_indices = ops.topology(padded_bins, self.blocking, block_rows,
blocks_per_row)
# TODO(tgale): This is unused. Remove the need for this in stk.
# For now, use meta init to save the device memory.
data = torch.empty(
column_indices.numel(),
self.blocking,
self.blocking,
dtype=x.dtype,
device="meta",
)
shape = (padded_tokens, self.ffn_dim_per_partition * self.num_experts)
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
shape, row_indices, column_indices)
return stk.Matrix(
shape,
data,
row_indices,
column_indices,
offsets,
column_indices_t,
offsets_t,
block_offsets_t,
)
def indices_and_padded_bins(
self, selected_experts: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor]:
# Sort the expert ids to produce the scatter/gather
# indices for the permutation.
selected_experts = selected_experts.int()
bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
# Histogram the expert ids to identify the number of
# tokens routed to each expert.
tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
# Round the token counts up to the block size used in
# the matrix multiplications. Calculate the starting
# position of each bin.
padded_tokens_per_expert = ops.round_up(tokens_per_expert,
self.blocking)
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
padded_bins = promote_scalar(padded_bins)
# Calculate the bin bounds for the sorted tokens.
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
bins = promote_scalar(bins)
return indices, bin_ids, bins, padded_bins, tokens_per_expert
@torch.inference_mode()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
# all_probs: (sequence_length, n_experts) and upcast for softmax
all_probs = F.softmax(gate_logits, dim=1, dtype=torch.float)
# weights, selected_experts: (sequence_length, top-k)
weights, selected_experts = torch.topk(all_probs, self.top_k, dim=-1)
weights /= weights.sum(dim=-1, keepdim=True)
weights = weights.flatten().to(x.dtype)
selected_experts = selected_experts.flatten()
(indices, bin_ids, bins, padded_bins,
_) = self.indices_and_padded_bins(selected_experts)
# Permute tokens and pad to prepare expert computation
# (top_k * sequence_length + padding, model_dim)
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins,
self.top_k)
# Create the sparse matrix topology
with torch.no_grad():
topo = self.topology(x, padded_bins)
# Perform the expert computation
# First Dense x Dense -> Sparse for w1 and w3,
# (top_k * sequence_length + padding, ffn_dim * n_experts)
x = stk.Matrix(
topo.size(),
F.silu(stk.ops.sdd(x, self.w1.t(), topo).data) *
stk.ops.sdd(x, self.w3.t(), topo).data,
topo.row_indices,
topo.column_indices,
topo.offsets,
topo.column_indices_t,
topo.offsets_t,
topo.block_offsets_t,
)
# Then Sparse x Dense -> Dense for w2
# (top_k * sequence_length + padding, model_dim)
x = stk.ops.dsd(x, self.w2)
x = tensor_model_parallel_all_reduce(x)
# Permute back and remove padding
# (top_k * sequence_length, model_dim)
x = ops.padded_scatter(
x,
indices,
bin_ids,
weights,
bins,
padded_bins,
self.top_k,
self.quantize_scatter_num_bits,
)
return x.view(*input_shape)
class MixtralDecoderLayer(nn.Module): class MixtralDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: MistralConfig, config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0 # Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
self.attention = MixtralAttention( self.self_attn = MixtralAttention(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=rope_theta,
sliding_window=config.sliding_window) sliding_window=config.sliding_window,
self.block_sparse_moe = BlockSparseMoE( linear_method=linear_method)
hidden_dim=self.hidden_size, self.block_sparse_moe = MixtralMoE(config=config,
ffn_dim=config.intermediate_size, linear_method=linear_method)
num_experts=config.num_local_experts, self.input_layernorm = RMSNorm(config.hidden_size,
top_k=config.num_experts_per_tok, eps=config.rms_norm_eps)
) self.post_attention_layernorm = RMSNorm(config.hidden_size,
self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
eps=config.rms_norm_eps)
self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
x: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: KVCache,
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event], cache_event: Optional[torch.cuda.Event],
residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
r = self.attention( # Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=self.attention_norm(x), hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, input_metadata=input_metadata,
cache_event=cache_event, cache_event=cache_event,
) )
h = x + r
r = self.block_sparse_moe(self.ffn_norm(h)) # Fully Connected
out = h + r hidden_states, residual = self.post_attention_layernorm(
return out hidden_states, residual)
hidden_states = self.block_sparse_moe(hidden_states)
return hidden_states, residual
class MixtralForCausalLM(nn.Module): class MixtralModel(nn.Module):
def __init__( def __init__(
self, self,
config: MistralConfig, config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None, linear_method: Optional[LinearMethodBase] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config
assert linear_method is None
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.tok_embeddings = VocabParallelEmbedding(
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
MixtralDecoderLayer(config) MixtralDecoderLayer(config, linear_method=linear_method)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward( def forward(
self, self,
@@ -475,20 +351,42 @@ class MixtralForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput: ) -> SamplerOutput:
hidden_states = self.tok_embeddings(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None
# forward
for i in range(len(self.layers)): for i in range(len(self.layers)):
cache_event = None if cache_events is None else cache_events[i] cache_event = None if cache_events is None else cache_events[i]
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer( hidden_states, residual = layer(positions, hidden_states,
positions, kv_caches[i], input_metadata,
hidden_states, cache_event, residual)
kv_caches[i], hidden_states, _ = self.norm(hidden_states, residual)
input_metadata, return hidden_states
cache_event,
)
hidden_states = self.norm(hidden_states) class MixtralForCausalLM(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = MixtralModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states return hidden_states
def sample( def sample(
@@ -496,7 +394,7 @@ class MixtralForCausalLM(nn.Module):
hidden_states: Optional[torch.Tensor], hidden_states: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> SamplerOutput:
next_tokens = self.sampler(self.output.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens
@@ -507,10 +405,11 @@ class MixtralForCausalLM(nn.Module):
revision: Optional[str] = None): revision: Optional[str] = None):
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("wqkv", "wq", "q"), ("qkv_proj", "q_proj", "q"),
("wqkv", "wk", "k"), ("qkv_proj", "k_proj", "k"),
("wqkv", "wv", "v"), ("qkv_proj", "v_proj", "v"),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
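
The new MixtralMoE splits the experts across tensor-parallel ranks, computes softmax routing weights, keeps the top-k experts per token, renormalizes, and lets each rank sum only the experts it owns before an all-reduce. The sketch below restates that arithmetic in plain Python for a single scalar "token". It is an illustration under assumptions, not the vLLM implementation: `split_experts`, `route`, and `local_partial` are hypothetical names, experts are modelled as simple callables, and the all-reduce is replaced by summing the per-rank partial results.

```python
import math
from typing import Callable, List, Sequence, Tuple


def split_experts(num_experts: int, tp_size: int, rank: int) -> List[int]:
    """Contiguous near-equal split of expert ids, one slice per rank.

    The first (num_experts % tp_size) ranks receive one extra expert,
    the same layout np.array_split produces for a range.
    """
    base, extra = divmod(num_experts, tp_size)
    start = rank * base + min(rank, extra)
    size = base + (1 if rank < extra else 0)
    return list(range(start, start + size))


def route(logits: Sequence[float], top_k: int) -> List[Tuple[int, float]]:
    """Softmax over experts, keep the top_k, renormalize to sum to 1."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    top = sorted(range(len(probs)), key=lambda i: probs[i],
                 reverse=True)[:top_k]
    norm = sum(probs[i] for i in top)
    return [(i, probs[i] / norm) for i in top]


def local_partial(x: float, logits: Sequence[float], top_k: int,
                  experts: Sequence[Callable[[float], float]],
                  local_ids: Sequence[int]) -> float:
    """Weighted sum over the experts owned by one rank.

    Experts that were not selected get weight 0, so they contribute
    nothing to the partial result.
    """
    weights = dict(route(logits, top_k))
    return sum(weights.get(i, 0.0) * experts[i](x) for i in local_ids)
```

Summing `local_partial` over all ranks gives the full mixture output, which is the role the all-reduce plays in the diff.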


@@ -50,9 +50,14 @@ class MPTAttention(nn.Module):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
self.total_num_heads = config.n_heads self.total_num_heads = config.n_heads
self.head_dim = self.d_model // self.total_num_heads
self.clip_qkv = config.attn_config["clip_qkv"] self.clip_qkv = config.attn_config["clip_qkv"]
self.qk_ln = config.attn_config["qk_ln"] self.qk_ln = config.attn_config["qk_ln"]
self.alibi_bias_max = config.attn_config["alibi_bias_max"] self.alibi_bias_max = config.attn_config["alibi_bias_max"]
if "kv_n_heads" in config.attn_config:
self.total_num_kv_heads = config.attn_config['kv_n_heads']
else:
self.total_num_kv_heads = self.total_num_heads
assert not config.attn_config["prefix_lm"] assert not config.attn_config["prefix_lm"]
assert config.attn_config["alibi"] assert config.attn_config["alibi"]
@@ -61,6 +66,7 @@ class MPTAttention(nn.Module):
self.d_model, self.d_model,
self.d_model // self.total_num_heads, self.d_model // self.total_num_heads,
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads,
bias=not config.no_bias, bias=not config.no_bias,
linear_method=linear_method, linear_method=linear_method,
) )
@@ -78,6 +84,17 @@ class MPTAttention(nn.Module):
assert self.total_num_heads % tp_world_size == 0 assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size self.num_heads = self.total_num_heads // tp_world_size
if self.total_num_kv_heads >= tp_world_size:
# Number of KV heads is greater than or equal to TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_world_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
# Create the alibi slopes and slice them. # Create the alibi slopes and slice them.
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads head_start = tp_rank * self.num_heads
@@ -91,7 +108,8 @@ class MPTAttention(nn.Module):
self.attn = PagedAttention(self.num_heads, self.attn = PagedAttention(self.num_heads,
self.head_dim, self.head_dim,
scaling, scaling,
alibi_slopes=alibi_slopes) alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads)
def forward( def forward(
self, self,
@@ -105,7 +123,7 @@ class MPTAttention(nn.Module):
qkv, _ = self.Wqkv(hidden_states) qkv, _ = self.Wqkv(hidden_states)
if self.clip_qkv is not None: if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.qk_ln: if self.qk_ln:
q = self.q_ln(q) q = self.q_ln(q)
k = self.k_ln(k) k = self.k_ln(k)
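
With grouped-query attention the key and value projections are narrower than the query projection, so the fused QKV output can no longer be cut into three equal chunks. The sketch below reproduces the per-rank size arithmetic from the MPT change in plain Python. `kv_partition` is a hypothetical helper name, and the checks raise `ValueError` where the diff uses assertions.

```python
from typing import Tuple


def kv_partition(total_heads: int, total_kv_heads: int, head_dim: int,
                 tp_size: int) -> Tuple[int, int, int]:
    """Return (num_kv_heads, q_size, kv_size) for one tensor-parallel rank."""
    if total_heads % tp_size != 0:
        raise ValueError("query heads must divide evenly across ranks")
    if total_kv_heads >= tp_size:
        # Enough KV heads: partition them across ranks.
        if total_kv_heads % tp_size != 0:
            raise ValueError("KV heads must divide evenly across ranks")
    else:
        # Fewer KV heads than ranks: each KV head is replicated.
        if tp_size % total_kv_heads != 0:
            raise ValueError("ranks must be a multiple of the KV heads")
    num_heads = total_heads // tp_size
    num_kv_heads = max(1, total_kv_heads // tp_size)
    return num_kv_heads, num_heads * head_dim, num_kv_heads * head_dim
```

For example, 32 query heads and 8 KV heads with a head dimension of 128 on 2 ranks give a query slice of 2048 and key/value slices of 512 each, which is why the split takes explicit sizes.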


@@ -40,11 +40,6 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int:
return int(max_shared_mem) return int(max_shared_mem)
def get_gpu_memory(gpu: int = 0) -> int:
"""Returns the total memory of the GPU in bytes."""
return torch.cuda.get_device_properties(gpu).total_memory
def get_cpu_memory() -> int: def get_cpu_memory() -> int:
"""Returns the total CPU memory of the node in bytes.""" """Returns the total CPU memory of the node in bytes."""
return psutil.virtual_memory().total return psutil.virtual_memory().total


@@ -134,14 +134,14 @@ class ModelRunner:
generation_token = seq_data.get_last_token_id() generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token]) input_tokens.append([generation_token])
context_len = seq_data.get_len() seq_len = seq_data.get_len()
if self.sliding_window is not None: position = seq_len - 1
context_len = min(context_len, self.sliding_window)
context_lens.append(context_len)
position = context_len - 1
input_positions.append([position]) input_positions.append([position])
context_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
context_lens.append(context_len)
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size] block_number = block_table[position // self.block_size]
block_offset = position % self.block_size block_offset = position % self.block_size
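
The change above separates two quantities that the old code tied together: the position of the token being decoded, which always follows the full sequence length, and the context length, which is capped by the sliding window. A small sketch of both versions makes the difference visible. The function names are hypothetical and the code only mirrors the arithmetic shown in the diff.

```python
from typing import Optional, Tuple


def decode_position_old(seq_len: int,
                        sliding_window: Optional[int]) -> Tuple[int, int]:
    """Old behaviour: the position is derived from the clamped length."""
    context_len = seq_len
    if sliding_window is not None:
        context_len = min(context_len, sliding_window)
    return context_len - 1, context_len


def decode_position_new(seq_len: int,
                        sliding_window: Optional[int]) -> Tuple[int, int]:
    """New behaviour: the position tracks the true sequence length."""
    position = seq_len - 1
    context_len = (seq_len if sliding_window is None else min(
        seq_len, sliding_window))
    return position, context_len
```

Once a sequence grows past the window, the old version keeps returning the same position for every new token, while the new version keeps advancing it and clamps only the context length.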


@@ -13,7 +13,6 @@ from vllm.model_executor.parallel_utils.parallel_state import (
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner from vllm.worker.model_runner import ModelRunner
from vllm.utils import get_gpu_memory
class Worker: class Worker:
@@ -81,7 +80,6 @@ class Worker:
# Profile the memory usage of the model and get the maximum number of # Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory. # cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# Execute a forward pass with dummy inputs to profile the memory usage # Execute a forward pass with dummy inputs to profile the memory usage
# of the model. # of the model.
@@ -90,8 +88,9 @@ class Worker:
# Calculate the number of blocks that can be allocated with the # Calculate the number of blocks that can be allocated with the
# profiled peak memory. # profiled peak memory.
torch.cuda.synchronize() torch.cuda.synchronize()
peak_memory = torch.cuda.max_memory_allocated() free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
total_gpu_memory = get_gpu_memory() peak_memory = total_gpu_memory - free_gpu_memory
cache_block_size = CacheEngine.get_cache_block_size( cache_block_size = CacheEngine.get_cache_block_size(
block_size, self.model_config, self.parallel_config) block_size, self.model_config, self.parallel_config)
num_gpu_blocks = int( num_gpu_blocks = int(