[V1] Support VLMs with fine-grained scheduling (#9871)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
@@ -39,6 +39,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.base import NestedTensors, PlaceholderRange
|
||||
from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.utils import is_list_of
|
||||
@@ -500,15 +501,20 @@ def input_processor_for_phi3v(ctx: InputContext,
|
||||
|
||||
# TODO: Move this to utils or integrate with clip.
|
||||
new_token_ids: List[int] = []
|
||||
placeholder_ranges: List[PlaceholderRange] = []
|
||||
placeholder_idx = 0
|
||||
while merged_token_ids:
|
||||
token_id = merged_token_ids.pop(0)
|
||||
if token_id == _IMAGE_TOKEN_ID:
|
||||
new_token_ids.extend(
|
||||
repeat_and_pad_token(
|
||||
_IMAGE_TOKEN_ID,
|
||||
repeat_count=image_feature_size[placeholder_idx],
|
||||
))
|
||||
replacement_ids = repeat_and_pad_token(
|
||||
_IMAGE_TOKEN_ID,
|
||||
repeat_count=image_feature_size[placeholder_idx],
|
||||
)
|
||||
placeholder_ranges.append({
|
||||
"offset": len(new_token_ids),
|
||||
"length": len(replacement_ids)
|
||||
})
|
||||
new_token_ids.extend(replacement_ids)
|
||||
placeholder_idx += 1
|
||||
else:
|
||||
new_token_ids.append(token_id)
|
||||
@@ -516,7 +522,8 @@ def input_processor_for_phi3v(ctx: InputContext,
|
||||
# NOTE: Create a defensive copy of the original inputs
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
multi_modal_data=multi_modal_data,
|
||||
multi_modal_placeholders={"image": placeholder_ranges})
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
@@ -669,32 +676,42 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
return image_embeds
|
||||
|
||||
def process_mm_inputs(self, **kwargs):
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return None
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
vision_embeddings: Optional[NestedTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
if vision_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, vision_embeddings,
|
||||
self.image_token_id)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object):
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
else:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
|
||||
if image_input is not None:
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, vision_embeddings,
|
||||
self.image_token_id)
|
||||
else:
|
||||
inputs_embeds = self.language_model.model.embed_tokens(
|
||||
input_ids)
|
||||
|
||||
# always pass the input via `inputs_embeds`
|
||||
# to make sure the computation graph is consistent
|
||||
# for `torch.compile` integration
|
||||
input_ids = None
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.process_mm_inputs(**kwargs)
|
||||
# always pass the input via `inputs_embeds`
|
||||
# to make sure the computation graph is consistent
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
|
||||
Reference in New Issue
Block a user