[Model] Add Idefics3 support (#9767)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: B-201 <Joy25810@foxmail.com>
Co-authored-by: B-201 <Joy25810@foxmail.com>
Jee Jee Li
2024-11-06 19:41:17 +08:00
committed by GitHub
parent 2003cc3513
commit a5bba7d234
8 changed files with 723 additions and 1 deletion
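
For context, a minimal offline-inference sketch of the newly supported model family; the checkpoint ID, prompt template, and image path below are illustrative assumptions, not part of this commit:

from PIL import Image

from vllm import LLM, SamplingParams

# Hypothetical usage sketch. The model ID and prompt template are
# assumptions; consult the model card for the exact chat format.
llm = LLM(model="HuggingFaceM4/Idefics3-8B-Llama3")
image = Image.open("example.jpg")

outputs = llm.generate(
    {
        "prompt": "User:<image>Describe this image.<end_of_utterance>\nAssistant:",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)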

vllm/model_executor/models/idefics2_vision_model.py

@@ -15,7 +15,7 @@
 # limitations under the License.
 """PyTorch Idefics2 model."""

-from typing import Optional
+from typing import Iterable, Optional, Tuple

 import torch
 from torch import nn
@@ -29,6 +29,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader


 class Idefics2VisionEmbeddings(nn.Module):
@@ -329,3 +330,25 @@ class Idefics2VisionTransformer(nn.Module):
         encoder_outputs = self.encoder(hidden_states)
         last_hidden_state = self.post_layernorm(encoder_outputs)
         return last_hidden_state
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                param = params_dict[name.replace(weight_name, param_name)]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
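
For reviewers new to this pattern: vLLM's QKVParallelLinear stores the attention projections as one fused qkv_proj parameter, so checkpoint tensors named q_proj/k_proj/v_proj must be renamed to the fused parameter and routed to the right shard. A standalone sketch of just that renaming, with a hypothetical checkpoint name:

# Standalone sketch of the remapping done in load_weights above; the
# checkpoint tensor name is hypothetical.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

ckpt_name = "encoder.layers.0.self_attn.k_proj.weight"
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in ckpt_name:
        fused_name = ckpt_name.replace(weight_name, param_name)
        print(fused_name, shard_id)
        # -> encoder.layers.0.self_attn.qkv_proj.weight k
        # shard_id tells the fused parameter's weight_loader which
        # slice (q, k, or v) of qkv_proj this tensor fills.
        break
else:
    # Names that match no stacked entry load 1:1 via default_weight_loader.
    print(ckpt_name)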