[Models] Add remaining model PP support (#7168)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai> Signed-off-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
committed by
GitHub
parent
303d44790a
commit
0f6d7a9a34
@@ -1,3 +1,4 @@
|
||||
from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
@@ -11,7 +12,7 @@ from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
@@ -19,7 +20,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
|
||||
|
||||
from .blip import (BlipVisionModel, dummy_image_for_blip,
|
||||
get_max_blip_image_tokens)
|
||||
from .interfaces import SupportsMultiModal
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (group_weights_with_prefix, init_vllm_registered_model,
|
||||
merge_multimodal_embeddings)
|
||||
|
||||
@@ -475,7 +476,7 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
|
||||
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
def __init__(self,
|
||||
config: Blip2Config,
|
||||
@@ -508,6 +509,16 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config, cache_config, quant_config)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
@cached_property
|
||||
def sampler(self):
|
||||
if hasattr(self.language_model, "sampler"):
|
||||
return self.language_model.sampler
|
||||
|
||||
return Sampler()
|
||||
|
||||
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
||||
h = w = self.config.vision_config.image_size
|
||||
expected_dims = (3, h, w)
|
||||
@@ -600,7 +611,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs: object,
|
||||
) -> SamplerOutput:
|
||||
) -> Union[SamplerOutput, IntermediateTensors]:
|
||||
"""Run forward pass for BLIP-2.
|
||||
|
||||
One key thing to understand is the `input_ids` already accounts for the
|
||||
@@ -631,26 +642,32 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
See also:
|
||||
:class:`Blip2ImageInputs`
|
||||
"""
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
|
||||
if image_input is not None:
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
inputs_embeds = self.language_model.model.get_input_embeddings(
|
||||
input_ids)
|
||||
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, vision_embeddings,
|
||||
BLIP2_IMAGE_TOKEN_ID)
|
||||
|
||||
if intermediate_tensors is not None:
|
||||
input_ids = None
|
||||
else:
|
||||
inputs_embeds = None
|
||||
else:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
inputs_embeds=inputs_embeds)
|
||||
if image_input is not None:
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
inputs_embeds = self.language_model.model.get_input_embeddings(
|
||||
input_ids)
|
||||
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, vision_embeddings,
|
||||
BLIP2_IMAGE_TOKEN_ID)
|
||||
|
||||
input_ids = None
|
||||
else:
|
||||
inputs_embeds = None
|
||||
|
||||
hidden_states = self.language_model.model(
|
||||
input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
Reference in New Issue
Block a user