[mypy] Enable type checking for test directory (#5017)
This commit is contained in:
@@ -20,7 +20,7 @@
|
||||
|
||||
# This file is based on the LLama model definition file in transformers
|
||||
"""PyTorch Cohere model."""
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -352,7 +352,7 @@ class CohereForCausalLM(nn.Module):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params = set()
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
for param_name, shard_name, shard_id in stacked_params_mapping:
|
||||
if shard_name not in name:
|
||||
|
||||
Reference in New Issue
Block a user