[Model]: Add transformers backend support (#11330)
# Adds support for `transformers` as a backend Following https://github.com/huggingface/transformers/pull/35235, a bunch of models should already be supported, we are ramping up support for more models. Thanks @Isotr0py for the TP support, and @hmellor for his help as well! This includes: - `trust_remote_code=True` support: any model on the hub, if it implements attention the correct way can be natively supported!! - tensor parallel support --------- Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Isotr0py <41363108+Isotr0py@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Isotr0py <2037008807@qq.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -40,6 +40,82 @@ If vLLM successfully returns text (for generative models) or hidden states (for
|
||||
Otherwise, please refer to [Adding a New Model](#new-model) for instructions on how to implement your model in vLLM.
|
||||
Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) to request vLLM support.
|
||||
|
||||
### Transformers fallback
|
||||
|
||||
After the merge of <gh-pr:11330>, `vllm` can fallback to models that are available in `transformers`. This does not work for all models for now, but most decoder language models are supported, and vision language model support is planned!
|
||||
|
||||
To check if the backend is `transformers`, you can simply do this:
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
llm = LLM(model=..., task="generate") # Name or path of your model
|
||||
llm.apply_model(lambda model: print(model.__class__))
|
||||
```
|
||||
|
||||
If it is `TransformersModel` then it means it's based on `transformers`!
|
||||
|
||||
#### Supported features
|
||||
|
||||
##### LORA and quantization
|
||||
|
||||
Both are not supported yet! Make sure to open an issue and we'll work on this together with the `transformers` team!
|
||||
|
||||
Usually `transformers` model load weights via the `load_adapters` API, that depends on PEFT. We need to work a bit to either use this api (for now this would result in some weights not being marked as loaded) or replace modules accordingly.
|
||||
|
||||
Hints as to how this would look like:
|
||||
|
||||
```python
|
||||
class TransformersModel(nn.Module, SupportsLoRA):
|
||||
def __init__(*):
|
||||
...
|
||||
self.model.load_adapter(vllm_config.load_config.model_loader_extra_config["qlora_adapter_name_or_path"])
|
||||
```
|
||||
|
||||
Blocker is that you need to specify supported lora layers, when we would ideally want to load whatever is inside the checkpoint!
|
||||
|
||||
##### Remote code
|
||||
|
||||
This fallback also means that any model on the hub that can be used in `transformers` with `trust_remote_code=True` that correctly implements attention can be used in production!
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
llm = LLM(model=..., task="generate", trust_remote_code=True) # Name or path of your model
|
||||
llm.apply_model(lambda model: print(model.__class__))
|
||||
```
|
||||
|
||||
A model just needs the following two things:
|
||||
|
||||
```python
|
||||
from transformers import PreTrainedModel
|
||||
from torch import nn
|
||||
|
||||
class MyAttention(nn.Module):
|
||||
|
||||
def forward(self, hidden_states, **kwargs): # <- kwargs are required
|
||||
|
||||
...
|
||||
attention_interface = attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
**kwargs,
|
||||
)
|
||||
...
|
||||
|
||||
class MyModel(PreTrainedModel):
|
||||
_supports_attention_backend = True
|
||||
```
|
||||
|
||||
Here is what happens in the background:
|
||||
|
||||
1. The config is loaded
|
||||
2. `MyModel` python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`.
|
||||
3. The `TransformersModel` backend is used. See `/model_executors/models/transformers`, which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`.
|
||||
|
||||
That's it!
|
||||
|
||||
### ModelScope
|
||||
|
||||
To use models from [ModelScope](https://www.modelscope.cn) instead of HuggingFace Hub, set an environment variable:
|
||||
|
||||
Reference in New Issue
Block a user