[Models] Add remaining model PP support (#7168)

Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Murali Andoorveedu
2024-10-03 19:56:58 -07:00
committed by GitHub
parent 303d44790a
commit 0f6d7a9a34
69 changed files with 2585 additions and 1344 deletions

View File

@@ -41,8 +41,9 @@ from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from .interfaces import SupportsMultiModal
from .utils import flatten_bn, merge_multimodal_embeddings
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, group_weights_with_prefix,
merge_multimodal_embeddings)
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
@@ -217,7 +218,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsMultiModal):
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self,
config: FuyuConfig,
@@ -242,6 +243,12 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
self.language_model = PersimmonForCausalLM(config.text_config,
cache_config=cache_config,
quant_config=quant_config)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@property
def sampler(self):
return self.language_model.sampler
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
@@ -297,23 +304,29 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
else:
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
else:
inputs_embeds = None
hidden_states = self.language_model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
@@ -336,34 +349,16 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
param = params_dict[name]
if "query_key_value" in name:
# copy from vllm/model_executor/models/bloom.py
# NOTE: Fuyu's fused QKV's output_dim has the shape of
# (num_heads * 3 * head_size), while the
# required shape is (3 * num_heads * head_size).
# Thus, we need weight conversion.
output_dim = getattr(param, "output_dim", None)
num_heads = self.config.num_attention_heads
if output_dim is not None:
loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
loaded_weight_shape[output_dim + 1:])
loaded_weight = loaded_weight.transpose(
output_dim, output_dim + 1)
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
# prepare weight iterators for components
weights_group = group_weights_with_prefix(weights)
# load vision embeddings
vision_params_dict = dict(self.vision_embed_tokens.named_parameters())
for name, loaded_weight in weights_group["vision_embed_tokens"]:
param = vision_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])