[Models] Add remaining model PP support (#7168)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
committed by GitHub
parent 303d44790a
commit 0f6d7a9a34
vllm/model_executor/models/fuyu.py
@@ -41,8 +41,9 @@ from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)
-from .interfaces import SupportsMultiModal
-from .utils import flatten_bn, merge_multimodal_embeddings
+from .interfaces import SupportsMultiModal, SupportsPP
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    merge_multimodal_embeddings)
 
 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 71011
 _NEWLINE_TOKEN_ID = 71019
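The import swap above is the visible surface of the change: FuyuForCausalLM now opts into the SupportsPP marker interface and pulls in the group_weights_with_prefix helper used by the rewritten load_weights further down. As a rough sketch of what such a marker interface amounts to (the names here are hypothetical; see vllm/model_executor/models/interfaces.py for the real definition):

from typing import Callable, Protocol, runtime_checkable

@runtime_checkable
class SupportsPPSketch(Protocol):
    """Hypothetical stand-in for vLLM's SupportsPP marker interface."""

    # Non-first pipeline stages call a factory like this to allocate
    # placeholder tensors before the real activations arrive.
    make_empty_intermediate_tensors: Callable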
@@ -217,7 +218,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
-class FuyuForCausalLM(nn.Module, SupportsMultiModal):
+class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
     def __init__(self,
                  config: FuyuConfig,
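Since SupportsPP is mixed in at the class declaration, the engine can discover pipeline-parallel capability with a plain isinstance check. A hypothetical guard, reusing the SupportsPPSketch protocol from above:

def assert_pp_supported(model: object) -> None:
    # isinstance works on runtime-checkable protocols: it verifies that
    # the model exposes the members the protocol declares.
    if not isinstance(model, SupportsPPSketch):
        raise ValueError(
            f"{type(model).__name__} does not support pipeline parallelism")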
@@ -242,6 +243,12 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
         self.language_model = PersimmonForCausalLM(config.text_config,
                                                    cache_config=cache_config,
                                                    quant_config=quant_config)
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors)
+
+    @property
+    def sampler(self):
+        return self.language_model.sampler
 
     def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
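The constructor now delegates make_empty_intermediate_tensors to the Persimmon backbone, and the new sampler property likewise forwards to it, so Fuyu itself carries no PP-specific state. For intuition, a factory of this kind returns named placeholder activations for the tensors that travel between pipeline stages; a hedged sketch with a made-up name and signature:

import torch

def make_empty_intermediate_tensors_sketch(
        batch_size: int, hidden_size: int,
        dtype: torch.dtype = torch.float32,
        device: str = "cpu") -> dict:
    # Placeholders that a non-first stage overwrites once it receives
    # the real hidden states from the previous stage.
    return {
        "hidden_states":
        torch.zeros(batch_size, hidden_size, dtype=dtype, device=device),
    }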
@@ -297,23 +304,29 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
         intermediate_tensors: Optional[IntermediateTensors] = None,
         **kwargs: object,
     ):
-        image_input = self._parse_and_validate_image_input(**kwargs)
-
-        if image_input is not None:
-            vision_embeddings = self._process_image_input(image_input)
-            inputs_embeds = self.language_model.model.embed_tokens(input_ids)
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, vision_embeddings,
-                self.image_token_id)
-
-        else:
-            inputs_embeds = None
+        if intermediate_tensors is not None:
+            input_ids = None
+            inputs_embeds = None
+        else:
+            image_input = self._parse_and_validate_image_input(**kwargs)
+
+            if image_input is not None:
+                vision_embeddings = self._process_image_input(image_input)
+                inputs_embeds = self.language_model.model.embed_tokens(
+                    input_ids)
+                inputs_embeds = merge_multimodal_embeddings(
+                    input_ids, inputs_embeds, vision_embeddings,
+                    self.image_token_id)
+
+            else:
+                inputs_embeds = None
 
         hidden_states = self.language_model(
             input_ids=input_ids,
             positions=positions,
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
+            intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )
         return hidden_states
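The reshaped forward encodes the pipeline contract: only the first stage works from token ids, so image parsing and the embedding merge are skipped whenever intermediate_tensors arrives from an upstream stage. A minimal, self-contained illustration of that branch (hypothetical helper, not vLLM code):

from typing import Optional

def route_inputs(input_ids: Optional[list],
                 intermediate_tensors: Optional[dict]) -> str:
    # Mirrors the branch added above: a stage handed intermediate
    # tensors ignores token ids and multimodal embeddings entirely.
    if intermediate_tensors is not None:
        return "non-first stage: consume upstream activations"
    return "first stage: embed tokens, merge vision embeddings"

assert route_inputs([1, 2, 3], None).startswith("first")
assert route_inputs(None, {"hidden_states": None}).startswith("non-first")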
@@ -336,34 +349,16 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                continue
-            param = params_dict[name]
-
-            if "query_key_value" in name:
-                # copy from vllm/model_executor/models/bloom.py
-                # NOTE: Fuyu's fused QKV's output_dim has the shape of
-                # (num_heads * 3 * head_size), while the
-                # required shape is (3 * num_heads * head_size).
-                # Thus, we need weight conversion.
-                output_dim = getattr(param, "output_dim", None)
-                num_heads = self.config.num_attention_heads
-                if output_dim is not None:
-                    loaded_weight_shape = loaded_weight.shape
-                    loaded_weight = loaded_weight.view(
-                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
-                        loaded_weight_shape[output_dim + 1:])
-                    loaded_weight = loaded_weight.transpose(
-                        output_dim, output_dim + 1)
-                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
-
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+        # prepare weight iterators for components
+        weights_group = group_weights_with_prefix(weights)
+
+        # load vision embeddings
+        vision_params_dict = dict(self.vision_embed_tokens.named_parameters())
+        for name, loaded_weight in weights_group["vision_embed_tokens"]:
+            param = vision_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load llm backbone
+        self.language_model.load_weights(weights_group["language_model"])
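The rewritten load_weights replaces the hand-rolled loop (including the fused-QKV reshuffling, which presumably now lives with the Persimmon backbone's own load_weights) with a split of the flat checkpoint stream into per-submodule groups. A hedged reimplementation of the grouping idea, keyed on the first dotted component of each weight name (the real helper in .utils may differ in detail):

from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

import torch

def group_weights_with_prefix_sketch(
    weights: Iterable[Tuple[str, torch.Tensor]],
) -> Dict[str, List[Tuple[str, torch.Tensor]]]:
    # "vision_embed_tokens.linear.weight" -> group "vision_embed_tokens",
    # remaining parameter path "linear.weight".
    groups: Dict[str, List[Tuple[str, torch.Tensor]]] = defaultdict(list)
    for name, tensor in weights:
        prefix, _, rest = name.partition(".")
        groups[prefix].append((rest, tensor))
    return groups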