[V1] Support VLMs with fine-grained scheduling (#9871)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.base import NestedTensors
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
@@ -448,6 +449,25 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
image_features = self._process_image_pixels(image_input)
|
||||
return self.multi_modal_projector(image_features)
|
||||
|
||||
def process_mm_inputs(self, **kwargs):
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return None
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
vision_embeddings: Optional[NestedTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if vision_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, vision_embeddings,
|
||||
self.config.image_token_index)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -455,6 +475,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
"""Run forward pass for LLaVA-1.5.
|
||||
@@ -494,24 +515,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
"""
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
else:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is not None:
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
inputs_embeds = self.language_model.model.get_input_embeddings(
|
||||
input_ids)
|
||||
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, vision_embeddings,
|
||||
self.config.image_token_index)
|
||||
else:
|
||||
inputs_embeds = self.language_model.model.get_input_embeddings(
|
||||
input_ids)
|
||||
|
||||
# always pass the input via `inputs_embeds`
|
||||
# to make sure the computation graph is consistent
|
||||
# for `torch.compile` integration
|
||||
input_ids = None
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.process_mm_inputs(**kwargs)
|
||||
# always pass the input via `inputs_embeds`
|
||||
# to make sure the computation graph is consistent
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
|
||||
Reference in New Issue
Block a user