[Core] Add Support for Default Modality Specific LoRAs [generate / chat completions] (#19126)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
@@ -272,3 +272,80 @@ The new format of `--lora-modules` is mainly to support the display of parent mo
    ]
}
```

## Default LoRA Models For Multimodal Models

Some multimodal models, e.g., [Granite Speech](https://huggingface.co/ibm-granite/granite-speech-3.3-8b) and [Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct), contain LoRA adapter(s) that are expected to always be applied when a given modality is present. This can be a bit tedious to manage with the above approaches, as it requires the user to send the `LoRARequest` (offline) or to route requests between the base model and the LoRA model (server) depending on the content of each request's multimodal data.

To this end, we allow registration of default multimodal LoRAs to handle this automatically: users can map each modality to a LoRA adapter, which is then applied whenever the corresponding inputs are present. Note that currently we only allow one LoRA per prompt; if several modalities are provided in a prompt, each registered to a different LoRA, none of them will be applied.

Example usage for offline inference:

```python
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_id = "ibm-granite/granite-speech-3.3-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)


def get_prompt(question: str, has_audio: bool):
    """Build the input prompt to send to vLLM."""
    if has_audio:
        question = f"<|audio|>{question}"
    chat = [
        {
            "role": "user",
            "content": question
        }
    ]
    return tokenizer.apply_chat_template(chat, tokenize=False)


model = LLM(
    model=model_id,
    enable_lora=True,
    max_lora_rank=64,
    max_model_len=2048,
    limit_mm_per_prompt={"audio": 1},
    # Will always pass a `LoRARequest` with the `model_id`
    # whenever audio is contained in the request data.
    default_mm_loras={"audio": model_id},
    enforce_eager=True,
)


question = "can you transcribe the speech into a written format?"
prompt_with_audio = get_prompt(
    question=question,
    has_audio=True,
)
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate

inputs = {
    "prompt": prompt_with_audio,
    "multi_modal_data": {
        "audio": audio,
    }
}


outputs = model.generate(
    inputs,
    sampling_params=SamplingParams(
        temperature=0.2,
        max_tokens=64,
    ),
)
```
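
As a quick follow-up (a minimal sketch, not part of the original example), the generated transcription can be read back from the returned `RequestOutput` objects:

```python
# Minimal follow-up sketch: print the generated transcription(s).
# `model.generate` returns one `RequestOutput` per input prompt; the decoded
# text of the first candidate lives in `output.outputs[0].text`.
for output in outputs:
    print(output.outputs[0].text)
```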

You can also pass a JSON dictionary to `--default-mm-loras` mapping modalities to LoRA model IDs. For example, when starting the server:

```bash
vllm serve ibm-granite/granite-speech-3.3-2b \
    --max-model-len 2048 \
    --enable-lora \
    --default-mm-loras '{"audio":"ibm-granite/granite-speech-3.3-2b"}' \
    --max-lora-rank 64
```

Note: Default multimodal LoRAs are currently only available for `.generate` and chat completions.
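
For the server case, requests then go through the usual OpenAI-compatible chat completions endpoint without naming a LoRA model; the registered default is applied whenever audio is present. Below is a minimal sketch of such a request, assuming the server started above is reachable at `http://localhost:8000/v1`, that the `openai` client package is installed, and that a local `mary_had_lamb.wav` file (a hypothetical path) is sent as a base64-encoded `audio_url` content part:

```python
# Minimal sketch (assumptions noted above): query the running server through
# the OpenAI-compatible chat completions API. No LoRA is named in the request;
# the default "audio" LoRA registered via --default-mm-loras is applied
# server-side because the request carries audio data.
import base64

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Hypothetical local audio file, encoded as a data URI for the request.
with open("mary_had_lamb.wav", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode("utf-8")

chat_completion = client.chat.completions.create(
    model="ibm-granite/granite-speech-3.3-2b",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "can you transcribe the speech into a written format?",
                },
                {
                    "type": "audio_url",
                    "audio_url": {"url": f"data:audio/wav;base64,{audio_b64}"},
                },
            ],
        }
    ],
    temperature=0.2,
    max_tokens=64,
)
print(chat_completion.choices[0].message.content)
```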