[Doc] Further update multi-modal impl doc (#33065)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -43,28 +43,73 @@ Further update the model as follows:
    )
    ```
- Implement [embed_multimodal][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
- Remove the embedding part from the [forward][torch.nn.Module.forward] method:
    - Move the multi-modal embedding to [embed_multimodal][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal].
    - The text embedding and embedding merge are handled automatically by a default implementation of [embed_input_ids][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_input_ids]. It does not need to be overridden in most cases.
??? code

    ```diff
      def forward(
          self,
          input_ids: torch.Tensor | None,
    -     pixel_values: torch.Tensor,
          positions: torch.Tensor,
          intermediate_tensors: IntermediateTensors | None = None,
          inputs_embeds: torch.Tensor | None = None,
      ) -> torch.Tensor:
    -     if inputs_embeds is None:
    -         inputs_embeds = self.get_input_embeddings()(input_ids)
    -
    -     if pixel_values is not None:
    -         image_features = self.get_image_features(
    -             pixel_values=pixel_values,
    -         )
    -         special_image_mask = self.get_placeholder_mask(
    -             input_ids,
    -             inputs_embeds=inputs_embeds,
    -             image_features=image_features,
    -         )
    -         inputs_embeds = inputs_embeds.masked_scatter(
    -             special_image_mask,
    -             image_features,
    -         )

          hidden_states = self.language_model(
              input_ids,
              positions,
              intermediate_tensors,
              inputs_embeds=inputs_embeds,
          )
          ...

    + def embed_multimodal(
    +     self,
    +     pixel_values: torch.Tensor,
    + ) -> MultiModalEmbeddings | None:
    +     return self.get_image_features(
    +         pixel_values=pixel_values,
    +     )
    ```
Below we provide a boilerplate of a typical implementation pattern of [embed_multimodal][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal], but feel free to adjust it to your own needs.
```python
def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor:
    image_features = self.vision_encoder(image_input)
    return self.multi_modal_projector(image_features)


def embed_multimodal(
    self,
    **kwargs: object,
) -> MultiModalEmbeddings | None:
    # Validate the multimodal input keyword arguments
    image_input = self._parse_and_validate_image_input(**kwargs)
    if image_input is None:
        return None

    # Run multimodal inputs through encoder and projector
    vision_embeddings = self._process_image_input(image_input)
    return vision_embeddings
```
!!! important
    The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g., image) of the request.
Reference in New Issue
Block a user