[mypy] Enable type checking for test directory (#5017)

This commit is contained in:
Cyrus Leung
2024-06-15 12:45:31 +08:00
committed by GitHub
parent 1b8a0d71cf
commit 0e9164b40a
92 changed files with 509 additions and 378 deletions

View File

@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Dict, List, Optional
import torch
@@ -9,13 +9,13 @@ class DummyLoRAManager:
def __init__(self):
super().__init__()
self._loras = {}
self._loras: Dict[str, LoRALayerWeights] = {}
def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
self._loras[module_name] = lora
def get_module_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
return self._loras.get(module_name, None)
def get_module_lora(self, module_name: str) -> LoRALayerWeights:
return self._loras[module_name]
def init_random_lora(self,
module_name: str,
@@ -68,11 +68,11 @@ class DummyLoRAManager:
module_name: str,
input_dim: int,
output_dims: List[int],
noop_lora_index: List[int] = None,
rank=8,
noop_lora_index: Optional[List[int]] = None,
rank: int = 8,
):
base_loras = []
noop_lora_index = set(noop_lora_index or [])
base_loras: List[LoRALayerWeights] = []
noop_lora_index_set = set(noop_lora_index or [])
for i, out_dim in enumerate(output_dims):
base_lora = self.init_lora(
@@ -80,7 +80,7 @@ class DummyLoRAManager:
input_dim,
out_dim,
rank=rank,
noop=i in noop_lora_index,
noop=i in noop_lora_index_set,
)
base_loras.append(base_lora)
packed_lora = PackedLoRALayerWeights.pack(base_loras)