[Model] Refactor BLIP/BLIP-2 to support composite model loading (#8407)
This commit is contained in:
@@ -10,11 +10,9 @@ from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.opt import OPTModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
@@ -22,12 +20,8 @@ from vllm.sequence import IntermediateTensors, SequenceData
|
||||
from .blip import (BlipVisionModel, dummy_image_for_blip,
|
||||
get_max_blip_image_tokens)
|
||||
from .interfaces import SupportsMultiModal
|
||||
from .utils import merge_multimodal_embeddings
|
||||
|
||||
_KEYS_TO_MODIFY_MAPPING = {
|
||||
"language_model.lm_head": "lm_head",
|
||||
"language_model.model": "language_model",
|
||||
}
|
||||
from .utils import (group_weights_with_prefix, init_vllm_registered_model,
|
||||
merge_multimodal_embeddings)
|
||||
|
||||
# We use this internally as placeholders since there is no image token
|
||||
# defined on the HuggingFace repo
|
||||
@@ -491,9 +485,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
|
||||
super().__init__()
|
||||
|
||||
# currently all existing BLIP-2 models have `tie_word_embeddings`
|
||||
# enabled
|
||||
assert config.tie_word_embeddings
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
@@ -514,17 +505,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.language_model = OPTModel(config.text_config, cache_config,
|
||||
quant_config)
|
||||
|
||||
self.unpadded_vocab_size = config.text_config.vocab_size
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
def get_lm_head(self):
|
||||
return self.language_model.decoder.embed_tokens
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config, cache_config, quant_config)
|
||||
|
||||
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
||||
h = w = self.config.vision_config.image_size
|
||||
@@ -653,7 +635,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
|
||||
if image_input is not None:
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
inputs_embeds = self.language_model.model.get_input_embeddings(
|
||||
input_ids)
|
||||
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, vision_embeddings,
|
||||
@@ -663,11 +646,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
else:
|
||||
inputs_embeds = None
|
||||
|
||||
hidden_states = self.language_model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
inputs_embeds=inputs_embeds)
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -676,56 +659,46 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.get_lm_head(), hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
return self.language_model.compute_logits(hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
return self.language_model.sample(logits, sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
# only doing this for language model part for now.
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
# prepare weight iterators for components
|
||||
weights_group = group_weights_with_prefix(weights)
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if "lm_head.weight" in name:
|
||||
continue
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in name:
|
||||
name = name.replace(key_to_modify, new_key)
|
||||
use_default_weight_loading = False
|
||||
if "vision" in name:
|
||||
if self.vision_model is not None:
|
||||
# BlipVisionModel does not need sharding
|
||||
use_default_weight_loading = True
|
||||
else:
|
||||
for (param_name, weight_name,
|
||||
shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
use_default_weight_loading = True
|
||||
if use_default_weight_loading:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
# load vision encoder
|
||||
self.vision_model.load_weights(weights_group["vision_model"])
|
||||
|
||||
# load query tokens
|
||||
for name, loaded_weight in weights_group["query_tokens"]:
|
||||
assert name == ""
|
||||
param = self.query_tokens
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
# load qformer
|
||||
qformer_params_dict = dict(self.qformer.named_parameters())
|
||||
for name, loaded_weight in weights_group["qformer"]:
|
||||
param = qformer_params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
# load mlp projector
|
||||
mlp_params_dict = dict(self.language_projection.named_parameters())
|
||||
for name, loaded_weight in weights_group["language_projection"]:
|
||||
param = mlp_params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
# load llm backbone
|
||||
self.language_model.load_weights(weights_group["language_model"])
|
||||
|
||||
Reference in New Issue
Block a user