Change the name to vLLM (#150)

This commit is contained in:
Woosuk Kwon
2023-06-17 03:07:40 -07:00
committed by GitHub
parent e5464ee484
commit 0b98ba15c7
90 changed files with 342 additions and 339 deletions

View File

@@ -17,9 +17,9 @@
# -- Project information -----------------------------------------------------
project = 'CacheFlow'
copyright = '2023, CacheFlow Team'
author = 'the CacheFlow Team'
project = 'vLLM'
copyright = '2023, vLLM Team'
author = 'the vLLM Team'
# -- General configuration ---------------------------------------------------
@@ -55,7 +55,7 @@ html_title = project
html_theme = 'sphinx_book_theme'
html_theme_options = {
'path_to_docs': 'docs/source',
'repository_url': 'https://github.com/WoosukKwon/cacheflow',
'repository_url': 'https://github.com/WoosukKwon/vllm',
'use_repository_button': True,
}

View File

@@ -1,8 +1,8 @@
Installation
============
CacheFlow is a Python library that includes some C++ and CUDA code.
CacheFlow can run on systems that meet the following requirements:
vLLM is a Python library that includes some C++ and CUDA code.
vLLM can run on systems that meet the following requirements:
* OS: Linux
* Python: 3.8 or higher
@@ -10,23 +10,23 @@ CacheFlow can run on systems that meet the following requirements:
* GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, etc.)
.. note::
As of now, CacheFlow does not support CUDA 12.
As of now, vLLM does not support CUDA 12.
If you are using Hopper or Lovelace GPUs, please use CUDA 11.8.
.. tip::
If you have trouble installing CacheFlow, we recommend using the NVIDIA PyTorch Docker image.
If you have trouble installing vLLM, we recommend using the NVIDIA PyTorch Docker image.
.. code-block:: console
$ # Pull the Docker image with CUDA 11.8.
$ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
Inside the Docker container, please execute :code:`pip uninstall torch` before installing CacheFlow.
Inside the Docker container, please execute :code:`pip uninstall torch` before installing vLLM.
Install with pip
----------------
You can install CacheFlow using pip:
You can install vLLM using pip:
.. code-block:: console
@@ -34,8 +34,8 @@ You can install CacheFlow using pip:
$ conda create -n myenv python=3.8 -y
$ conda activate myenv
$ # Install CacheFlow.
$ pip install cacheflow # This may take 5-10 minutes.
$ # Install vLLM.
$ pip install vllm # This may take 5-10 minutes.
.. _build_from_source:
@@ -43,10 +43,10 @@ You can install CacheFlow using pip:
Build from source
-----------------
You can also build and install CacheFlow from source.
You can also build and install vLLM from source.
.. code-block:: console
$ git clone https://github.com/WoosukKwon/cacheflow.git
$ cd cacheflow
$ git clone https://github.com/WoosukKwon/vllm.git
$ cd vllm
$ pip install -e . # This may take 5-10 minutes.

View File

@@ -8,7 +8,7 @@ Placeholder.
.. code-block:: python
from cacheflow import LLM, SamplingParams
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [

View File

@@ -1,5 +1,5 @@
Welcome to CacheFlow!
=====================
Welcome to vLLM!
================
Documentation
-------------

View File

@@ -3,30 +3,30 @@
Adding a New Model
==================
This document provides a high-level guide on integrating a `HuggingFace Transformers <https://github.com/huggingface/transformers>`_ model into CacheFlow.
This document provides a high-level guide on integrating a `HuggingFace Transformers <https://github.com/huggingface/transformers>`_ model into vLLM.
.. note::
The complexity of adding a new model depends heavily on the model's architecture.
The process is considerably straightforward if the model shares a similar architecture with an existing model in CacheFlow.
The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM.
However, for models that include new operators (e.g., a new attention mechanism), the process can be a bit more complex.
.. tip::
If you are encountering issues while integrating your model into CacheFlow, feel free to open an issue on our `GitHub <https://github.com/WoosukKwon/cacheflow/issues>`_ repository.
If you are encountering issues while integrating your model into vLLM, feel free to open an issue on our `GitHub <https://github.com/WoosukKwon/vllm/issues>`_ repository.
We will be happy to help you out!
0. Fork the CacheFlow repository
0. Fork the vLLM repository
--------------------------------
Start by forking our `GitHub <https://github.com/WoosukKwon/cacheflow/issues>`_ repository and then :ref:`build it from source <build_from_source>`.
Start by forking our `GitHub <https://github.com/WoosukKwon/vllm/issues>`_ repository and then :ref:`build it from source <build_from_source>`.
This gives you the ability to modify the codebase and test your model.
1. Bring your model code
------------------------
Clone the PyTorch model code from the HuggingFace Transformers repository and put it into the `cacheflow/model_executor/models <https://github.com/WoosukKwon/cacheflow/tree/main/cacheflow/model_executor/models>`_ directory.
For instance, CacheFlow's `OPT model <https://github.com/WoosukKwon/cacheflow/blob/main/cacheflow/model_executor/models/opt.py>`_ was adpated from the HuggingFace's `modeling_opt.py <https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py>`_ file.
Clone the PyTorch model code from the HuggingFace Transformers repository and put it into the `vllm/model_executor/models <https://github.com/WoosukKwon/vllm/tree/main/vllm/model_executor/models>`_ directory.
For instance, vLLM's `OPT model <https://github.com/WoosukKwon/vllm/blob/main/vllm/model_executor/models/opt.py>`_ was adpated from the HuggingFace's `modeling_opt.py <https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py>`_ file.
.. warning::
When copying the model code, make sure to review and adhere to the code's copyright and licensing terms.
@@ -62,11 +62,11 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
+) -> Dict[int, SequenceOutputs]:
3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
4. Replace the attention operation with either :code:`GPTCacheFlowAttention` or :code:`GPTNeoXCacheFlowAttention`, depending on the model's architecture.
4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.
.. note::
Currently, CacheFlow supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
If your model employs a different attention mechanism, you will need to implement a new attention layer in CacheFlow.
Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM.
3. (Optional) Implement tensor parallelism support
@@ -91,4 +91,4 @@ While the process is straightforward for most layers, the tensor-parallel layers
5. Register your model
----------------------
Finally, include your :code:`*ForCausalLM` class in `cacheflow/model_executor/models/__init__.py <https://github.com/WoosukKwon/cacheflow/blob/main/cacheflow/model_executor/models/__init__.py>`_ and register it to the :code:`_MODEL_REGISTRY` in `cacheflow/model_executor/model_loader.py <https://github.com/WoosukKwon/cacheflow/blob/main/cacheflow/model_executor/model_loader.py>`_.
Finally, include your :code:`*ForCausalLM` class in `vllm/model_executor/models/__init__.py <https://github.com/WoosukKwon/vllm/blob/main/vllm/model_executor/models/__init__.py>`_ and register it to the :code:`_MODEL_REGISTRY` in `vllm/model_executor/model_loader.py <https://github.com/WoosukKwon/vllm/blob/main/vllm/model_executor/model_loader.py>`_.

View File

@@ -3,8 +3,8 @@
Supported Models
================
CacheFlow supports a variety of generative Transformer models in `HuggingFace Transformers <https://github.com/huggingface/transformers>`_.
The following is the list of model architectures that are currently supported by CacheFlow.
vLLM supports a variety of generative Transformer models in `HuggingFace Transformers <https://github.com/huggingface/transformers>`_.
The following is the list of model architectures that are currently supported by vLLM.
Alongside each architecture, we include some popular models that use it.
.. list-table::
@@ -22,19 +22,19 @@ Alongside each architecture, we include some popular models that use it.
* - :code:`OPTForCausalLM`
- OPT, OPT-IML
If your model uses one of the above model architectures, you can seamlessly run your model with CacheFlow.
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
Alternatively, you can raise an issue on our `GitHub <https://github.com/WoosukKwon/cacheflow/issues>`_ project.
Alternatively, you can raise an issue on our `GitHub <https://github.com/WoosukKwon/vllm/issues>`_ project.
.. tip::
The easiest way to check if your model is supported is to run the program below:
.. code-block:: python
from cacheflow import LLM
from vllm import LLM
llm = LLM(model=...) # Name or path of your model
output = llm.generate("Hello, my name is")
print(output)
If CacheFlow successfully generates text, it indicates that your model is supported.
If vLLM successfully generates text, it indicates that your model is supported.