[Model][LoRA] LoRA support added for MiniCPMV2.5 (#7199)

Jee Jee Li
2024-09-29 14:59:45 +08:00
committed by GitHub
parent bc2ef1f77c
commit 3d49776bbb
8 changed files with 377 additions and 30 deletions


@@ -1,7 +1,7 @@
 import itertools
 from collections import UserDict
-from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
-                    Union, overload)
+from typing import (Any, Dict, Iterable, List, Literal, Optional, Protocol,
+                    Tuple, Union, overload)
 
 import torch
 import torch.nn as nn
@@ -329,3 +329,21 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
         })
 
     return make_empty_intermediate_tensors
+
+
+class LLMWrapper(nn.Module):
+    """
+    To align with the key names of LoRA trained with PEFT, we need to add an
+    additional layer to the llm's implementation.
+    """
+
+    def __init__(self, llm: nn.Module, name: str) -> None:
+        super().__init__()
+        self.model_name = name
+        setattr(self, name, llm)
+
+    def forward(self, *args, **kwargs) -> Any:
+        return getattr(self, self.model_name)(*args, **kwargs)
+
+    def embed_tokens(self, *args, **kwargs) -> Any:
+        return getattr(self, self.model_name).embed_tokens(*args, **kwargs)
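
For context (not part of the commit): PEFT saves LoRA weights under the module paths of the original Hugging Face model, where the language model usually sits under an attribute such as `llm`, so checkpoint keys look like `llm.model.layers.0...`. `LLMWrapper` reintroduces that attribute level around vLLM's inner language model, so the wrapped module's parameter names line up with those keys. A minimal sketch of the effect, using the class above and a hypothetical `ToyLM` stand-in for the real model:

import torch
import torch.nn as nn


class ToyLM(nn.Module):
    """Hypothetical stand-in for the wrapped language model."""

    def __init__(self) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 4)
        self.proj = nn.Linear(4, 4)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.proj(self.embed_tokens(input_ids))


wrapped = LLMWrapper(ToyLM(), name="llm")

# Every parameter now carries the "llm." prefix that a PEFT LoRA
# checkpoint trained against the HF model layout would reference:
print(list(wrapped.state_dict().keys()))
# ['llm.embed_tokens.weight', 'llm.proj.weight', 'llm.proj.bias']

# forward() and embed_tokens() transparently delegate to the inner model.
hidden = wrapped(torch.tensor([[1, 2, 3]]))
embeds = wrapped.embed_tokens(torch.tensor([[1, 2, 3]]))

Because both forward() and embed_tokens() delegate to the inner model, only the parameter namespace changes; the rest of the model code can treat the wrapper exactly like the model it wraps.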