[Model] Clean up MiniCPMV (#10751)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -1,7 +1,7 @@
 import itertools
 from dataclasses import dataclass, field
-from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
-                    Optional, Protocol, Set, Tuple, Union, overload)
+from typing import (Callable, Dict, Iterable, List, Literal, Mapping, Optional,
+                    Protocol, Set, Tuple, Union, overload)
 
 import torch
 import torch.nn as nn
@@ -560,30 +560,6 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
     return make_empty_intermediate_tensors
 
 
-class LLMWrapper(nn.Module):
-    """
-    To align with the key names of LoRA trained with PEFT, we need to add an
-    additional layer to the llm's implementation.
-    """
-
-    def __init__(self, llm: nn.Module, name: str) -> None:
-        super().__init__()
-        self.model_name = name
-        setattr(self, name, llm)
-
-    def __getattr__(self, key: str):
-        llm = super().__getattr__(self.model_name)
-        if key == self.model_name:
-            return llm
-
-        return getattr(llm, key)
-
-    # We need to explicitly override this
-    def __call__(self, *args: Any, **kwargs: Any) -> Any:
-        llm = super().__getattr__(self.model_name)
-        return llm(*args, **kwargs)
-
-
 def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
     """
     Get the available attention backend for Vision Transformer.
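For context on what the removed helper did, here is a minimal usage sketch (not part of this commit; ToyLLM is a hypothetical stand-in module). It shows how LLMWrapper delegated attribute access to the wrapped module while registering it under a chosen attribute name, so that state-dict keys gained the prefix expected by LoRA checkpoints trained with PEFT:

    import torch.nn as nn

    class ToyLLM(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.embed = nn.Embedding(10, 4)

    # Register the inner model under the attribute name "model".
    wrapper = LLMWrapper(ToyLLM(), name="model")

    # __getattr__ delegates unknown attributes to the wrapped module.
    print(wrapper.embed)                      # Embedding(10, 4)

    # Parameters are exposed with the "model." prefix, matching PEFT-style keys.
    print(list(wrapper.state_dict().keys()))  # ['model.embed.weight']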