Upgrade transformers version to 4.36.0 (#2046)
@@ -29,7 +29,7 @@ import torch
 import torch.nn.functional as F
 
 from torch import nn
-from transformers import MistralConfig
+from transformers import MixtralConfig
 
 try:
     import megablocks.ops as ops
@@ -395,7 +395,7 @@ class MixtralDecoderLayer(nn.Module):
 
     def __init__(
         self,
-        config: MistralConfig,
+        config: MixtralConfig,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -443,7 +443,7 @@ class MixtralForCausalLM(nn.Module):
 
     def __init__(
         self,
-        config: MistralConfig,
+        config: MixtralConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
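
For context: MixtralConfig was first added in transformers 4.36.0, which is why the Mixtral model code previously borrowed MistralConfig and why this upgrade swaps the import. Below is a minimal compatibility sketch, assuming a codebase that must still import on pre-4.36.0 releases; it is a hypothetical shim, not what this PR does (the PR simply requires transformers 4.36.0):

# Hypothetical version-guard, NOT part of this commit.
# MixtralConfig only exists in transformers >= 4.36.0.
try:
    from transformers import MixtralConfig
except ImportError:
    # Older transformers: fall back to MistralConfig, which shares the
    # base fields this model code reads (hidden_size, num_attention_heads,
    # num_hidden_layers, ...), though it lacks the MoE-specific fields
    # such as num_local_experts and num_experts_per_tok.
    from transformers import MistralConfig as MixtralConfig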