[Model][LoRA] LoRA support added for MiniCPMV2.5 (#7199)
@@ -1,7 +1,7 @@
 import itertools
 from collections import UserDict
-from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
-                    Union, overload)
+from typing import (Any, Dict, Iterable, List, Literal, Optional, Protocol,
+                    Tuple, Union, overload)
 
 import torch
 import torch.nn as nn
@@ -329,3 +329,21 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
         })
 
     return make_empty_intermediate_tensors
+
+
+class LLMWrapper(nn.Module):
+    """
+    To align with the key names of LoRA trained with PEFT, we need to add an
+    additional layer to the llm's implementation.
+    """
+
+    def __init__(self, llm: nn.Module, name: str) -> None:
+        super().__init__()
+        self.model_name = name
+        setattr(self, name, llm)
+
+    def forward(self, *args, **kwargs) -> Any:
+        return getattr(self, self.model_name)(*args, **kwargs)
+
+    def embed_tokens(self, *args, **kwargs) -> Any:
+        return getattr(self, self.model_name).embed_tokens(*args, **kwargs)
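
For context, a minimal sketch of what the wrapper does to parameter key names. The toy inner module and attribute names below are illustrative assumptions, not code from this commit: the point is that stashing the backbone under an attribute called "model" adds one level of nesting, which is the extra prefix that PEFT-trained LoRA checkpoints use.

    import torch.nn as nn
    from typing import Any

    class LLMWrapper(nn.Module):
        # Same idea as the class added in this commit: store the wrapped module
        # under a configurable attribute name and delegate forward() to it.
        def __init__(self, llm: nn.Module, name: str) -> None:
            super().__init__()
            self.model_name = name
            setattr(self, name, llm)

        def forward(self, *args, **kwargs) -> Any:
            return getattr(self, self.model_name)(*args, **kwargs)

    # Hypothetical stand-in for the real language-model backbone.
    inner = nn.Linear(4, 4)
    wrapped = LLMWrapper(inner, name="model")

    # The bare module exposes 'weight' / 'bias'; the wrapped one exposes
    # 'model.weight' / 'model.bias', one level deeper.
    print([k for k, _ in inner.named_parameters()])
    print([k for k, _ in wrapped.named_parameters()])

If the outer multimodal model then assigns this wrapper to an attribute named "llm" (as the docstring suggests), the full parameter keys take the form llm.model.*, lining up with the module names that PEFT records for LoRA targets.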