[Doc][2/N] Reorganize Models and Usage sections (#11755)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-06 21:40:31 +08:00
committed by GitHub
parent 996357e480
commit ee77fdb5de
45 changed files with 265 additions and 238 deletions

View File

@@ -0,0 +1,102 @@
(new-model-basic)=
# Basic Implementation
This guide walks you through the steps to implement a basic vLLM model.
## 1. Bring your model code
First, clone the PyTorch model code from the source repository.
For instance, vLLM's [OPT model](gh-file:vllm/model_executor/models/opt.py) was adapted from
HuggingFace's [modeling_opt.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py) file.
```{warning}
Make sure to review and adhere to the original code's copyright and licensing terms!
```
## 2. Make your code compatible with vLLM
To ensure compatibility with vLLM, your model must meet the following requirements:
### Initialization Code
All vLLM modules within the model must include a `prefix` argument in their constructor. This `prefix` is typically the full name of the module in the model's state dictionary and is crucial for:
- Runtime support: vLLM's attention operators are registered in a model's state by their full names. Each attention operator must have a unique prefix as its layer name to avoid conflicts.
- Non-uniform quantization support: A quantized checkpoint can selectively quantize certain layers while keeping others in full precision. By providing the `prefix` during initialization, vLLM can match the current layer's `prefix` with the quantization configuration to determine if the layer should be initialized in quantized mode.
The initialization code should look like this:
```python
from torch import nn
from vllm.config import VllmConfig
from vllm.attention import Attention
class MyAttention(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str):
super().__init__()
self.attn = Attention(prefix=f"{prefix}.attn")
class MyDecoderLayer(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str):
super().__init__()
self.self_attn = MyAttention(prefix=f"{prefix}.self_attn")
class MyModel(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str):
super().__init__()
self.layers = nn.ModuleList(
[MyDecoderLayer(vllm_config, prefix=f"{prefix}.layers.{i}") for i in range(vllm_config.model_config.hf_config.num_hidden_layers)]
)
class MyModelForCausalLM(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.model = MyModel(vllm_config, prefix=f"{prefix}.model")
```
### Computation Code
Rewrite the {meth}`~torch.nn.Module.forward` method of your model to remove any unnecessary code, such as training-specific code. Modify the input parameters to treat `input_ids` and `positions` as flattened tensors with a single batch size dimension, without a max-sequence length dimension.
```python
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
...
```
```{note}
Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM.
```
For reference, check out our [Llama implementation](gh-file:vllm/model_executor/models/llama.py). vLLM already supports a large number of models. It is recommended to find a model similar to yours and adapt it to your model's architecture. Check out <gh-dir:vllm/model_executor/models> for more examples.
## 3. (Optional) Implement tensor parallelism and quantization support
If your model is too large to fit into a single GPU, you can use tensor parallelism to manage it.
To do this, substitute your model's linear and embedding layers with their tensor-parallel versions.
For the embedding layer, you can simply replace {class}`torch.nn.Embedding` with `VocabParallelEmbedding`. For the output LM head, you can use `ParallelLMHead`.
When it comes to the linear layers, we provide the following options to parallelize them:
- `ReplicatedLinear`: Replicates the inputs and weights across multiple GPUs. No memory saving.
- `RowParallelLinear`: The input tensor is partitioned along the hidden dimension. The weight matrix is partitioned along the rows (input dimension). An *all-reduce* operation is performed after the matrix multiplication to reduce the results. Typically used for the second FFN layer and the output linear transformation of the attention layer.
- `ColumnParallelLinear`: The input tensor is replicated. The weight matrix is partitioned along the columns (output dimension). The result is partitioned along the column dimension. Typically used for the first FFN layer and the separated QKV transformation of the attention layer in the original Transformer.
- `MergedColumnParallelLinear`: Column-parallel linear that merges multiple `ColumnParallelLinear` operators. Typically used for the first FFN layer with weighted activation functions (e.g., SiLU). This class handles the sharded weight loading logic of multiple weight matrices.
- `QKVParallelLinear`: Parallel linear layer for the query, key, and value projections of the multi-head and grouped-query attention mechanisms. When number of key/value heads are less than the world size, this class replicates the key/value heads properly. This class handles the weight loading and replication of the weight matrices.
Note that all the linear layers above take `linear_method` as an input. vLLM will set this parameter according to different quantization schemes to support weight quantization.
## 4. Implement the weight loading logic
You now need to implement the `load_weights` method in your `*ForCausalLM` class.
This method should load the weights from the HuggingFace's checkpoint file and assign them to the corresponding layers in your model. Specifically, for `MergedColumnParallelLinear` and `QKVParallelLinear` layers, if the original model has separated weight matrices, you need to load the different parts separately.
## 5. Register your model
See [this page](#new-model-registration) for instructions on how to register your new model to be used by vLLM.

View File

@@ -0,0 +1,26 @@
(new-model)=
# Adding a New Model
This section provides more information on how to integrate a [HuggingFace Transformers](https://github.com/huggingface/transformers) model into vLLM.
```{toctree}
:caption: Contents
:maxdepth: 1
basic
registration
multimodal
```
```{note}
The complexity of adding a new model depends heavily on the model's architecture.
The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM.
However, for models that include new operators (e.g., a new attention mechanism), the process can be a bit more complex.
```
```{tip}
If you are encountering issues while integrating your model into vLLM, feel free to open a [GitHub issue](https://github.com/vllm-project/vllm/issues)
or ask on our [developer slack](https://slack.vllm.ai).
We will be happy to help you out!
```

View File

@@ -0,0 +1,139 @@
(enabling-multimodal-inputs)=
# Enabling Multimodal Inputs
This document walks you through the steps to extend a basic model so that it accepts [multi-modal inputs](#multimodal-inputs).
## 1. Update the base vLLM model
It is assumed that you have already implemented the model in vLLM according to [these steps](#new-model-basic).
Further update the model as follows:
- Implement the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
```diff
+ from vllm.model_executor.models.interfaces import SupportsMultiModal
- class YourModelForImage2Seq(nn.Module):
+ class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
```
```{note}
The model class does not have to be named {code}`*ForCausalLM`.
Check out [the HuggingFace Transformers documentation](https://huggingface.co/docs/transformers/model_doc/auto#multimodal) for some examples.
```
- If you haven't already done so, reserve a keyword parameter in {meth}`~torch.nn.Module.forward`
for each input tensor that corresponds to a multi-modal input, as shown in the following example:
```diff
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
+ pixel_values: torch.Tensor,
) -> SamplerOutput:
```
## 2. Register input mappers
For each modality type that the model accepts as input, decorate the model class with {meth}`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in {meth}`~torch.nn.Module.forward`.
```diff
from vllm.model_executor.models.interfaces import SupportsMultiModal
+ from vllm.multimodal import MULTIMODAL_REGISTRY
+ @MULTIMODAL_REGISTRY.register_image_input_mapper()
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
```
A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
```{seealso}
[Input Processing Pipeline](#input-processing-pipeline)
```
## 3. Register maximum number of multi-modal tokens
For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data item
and register it via {meth}`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_max_multimodal_tokens>`.
```diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_input_mapper()
+ @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
```
Here are some examples:
- Image inputs (static feature size): [LLaVA-1.5 Model](gh-file:vllm/model_executor/models/llava.py)
- Image inputs (dynamic feature size): [LLaVA-NeXT Model](gh-file:vllm/model_executor/models/llava_next.py)
```{seealso}
[Input Processing Pipeline](#input-processing-pipeline)
```
## 4. (Optional) Register dummy data
During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models.
In such cases, you can define your own dummy data by registering a factory method via {meth}`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.
```diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
+ @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
```
```{note}
The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step.
```
Here are some examples:
- Image inputs (static feature size): [LLaVA-1.5 Model](gh-file:vllm/model_executor/models/llava.py)
- Image inputs (dynamic feature size): [LLaVA-NeXT Model](gh-file:vllm/model_executor/models/llava_next.py)
```{seealso}
[Input Processing Pipeline](#input-processing-pipeline)
```
## 5. (Optional) Register input processor
Sometimes, there is a need to process inputs at the {class}`~vllm.LLMEngine` level before they are passed to the model executor.
This is often due to the fact that unlike implementations in HuggingFace Transformers, the reshaping and/or expansion of multi-modal embeddings needs to take place outside model's {meth}`~torch.nn.Module.forward` call.
You can register input processors via {meth}`INPUT_REGISTRY.register_input_processor <vllm.inputs.registry.InputRegistry.register_input_processor>`.
```diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
+ @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
```
A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
Here are some examples:
- Insert static number of image tokens: [LLaVA-1.5 Model](gh-file:vllm/model_executor/models/llava.py)
- Insert dynamic number of image tokens: [LLaVA-NeXT Model](gh-file:vllm/model_executor/models/llava_next.py)
```{seealso}
[Input Processing Pipeline](#input-processing-pipeline)
```

View File

@@ -0,0 +1,56 @@
(new-model-registration)=
# Model Registration
vLLM relies on a model registry to determine how to run each model.
A list of pre-registered architectures can be found on the [Supported Models](#supported-models) page.
If your model is not on this list, you must register it to vLLM.
This page provides detailed instructions on how to do so.
## Built-in models
To add a model directly to the vLLM library, start by forking our [GitHub repository](https://github.com/vllm-project/vllm) and then [build it from source](#build-from-source).
This gives you the ability to modify the codebase and test your model.
After you have implemented your model (see [tutorial](#new-model-basic)), put it into the <gh-dir:vllm/model_executor/models> directory.
Then, add your model class to `_VLLM_MODELS` in <gh-file:vllm/model_executor/models/registry.py> so that it is automatically registered upon importing vLLM.
You should also include an example HuggingFace repository for this model in <gh-file:tests/models/registry.py> to run the unit tests.
Finally, update the [Supported Models](#supported-models) documentation page to promote your model!
```{important}
The list of models in each section should be maintained in alphabetical order.
```
## Out-of-tree models
You can load an external model using a plugin without modifying the vLLM codebase.
```{seealso}
[vLLM's Plugin System](#plugin-system)
```
To register the model, use the following code:
```python
from vllm import ModelRegistry
from your_code import YourModelForCausalLM
ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
```
If your model imports modules that initialize CUDA, consider lazy-importing it to avoid errors like `RuntimeError: Cannot re-initialize CUDA in forked subprocess`:
```python
from vllm import ModelRegistry
ModelRegistry.register_model("YourModelForCausalLM", "your_code:YourModelForCausalLM")
```
```{important}
If your model is a multimodal model, ensure the model class implements the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
Read more about that [here](#enabling-multimodal-inputs).
```
```{note}
Although you can directly put these code snippets in your script using `vllm.LLM`, the recommended way is to place these snippets in a vLLM plugin. This ensures compatibility with various vLLM features like distributed inference and the API server.
```