Add contributing guideline and mypy config (#122)
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
"""Utilities for selecting and loading models."""
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
@@ -17,7 +19,7 @@ _MODEL_REGISTRY = {
|
||||
}
|
||||
|
||||
|
||||
def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
|
||||
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
if arch in _MODEL_REGISTRY:
|
||||
|
||||
Reference in New Issue
Block a user