Re-enable the 80 char line width limit (#3305)
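For context, the change re-enables a maximum line width of 80 characters for this file. As a rough illustration of what such a limit flags (a hypothetical standalone sketch, not the project's actual lint tooling, which a real project would delegate to a linter), a checker might look like this:

# check_width.py -- hypothetical sketch of an 80-char line-width check.
import sys

LIMIT = 80

def long_lines(path: str, limit: int = LIMIT):
    """Yield (line number, length) for every line exceeding `limit`."""
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            length = len(line.rstrip("\n"))
            if length > limit:
                yield lineno, length

if __name__ == "__main__":
    for lineno, length in long_lines(sys.argv[1]):
        print(f"{sys.argv[1]}:{lineno}: {length} > {LIMIT} chars")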
@@ -1,5 +1,6 @@
 # coding=utf-8
-# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
+# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
+# All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,7 +17,8 @@
 # This code is based off the following work:
 # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
 # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
-"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights."""
+"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
+model compatible with HuggingFace weights."""
 from typing import List, Optional, Tuple
 
 import torch
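One observation about this hunk (an editorial note, not part of the diff): wrapping a triple-quoted docstring puts a literal newline into the module's __doc__. That is harmless for documentation, but it is a different technique from the string-literal concatenation used for the runtime error message below, where the wrap point must not change the text. A minimal sketch:

def example():
    """Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
    model compatible with HuggingFace weights."""

# The wrap point survives as a literal newline inside __doc__:
assert "\n" in example.__doc__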
@@ -102,9 +104,9 @@ class StablelmAttention(nn.Module):
         self.kv_size = self.num_key_value_heads * self.head_dim
         self.qkv_bias = getattr(config, "use_qkv_bias", False)
         if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {self.num_heads}).")
+            raise ValueError(f"hidden_size must be divisible by num_heads "
+                             f"(got `hidden_size`: {self.hidden_size}"
+                             f" and `num_heads`: {self.num_heads}).")
 
         self.qkv_proj = QKVParallelLinear(self.hidden_size,
                                           self.head_dim,
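The rewrapped ValueError relies on Python's compile-time concatenation of adjacent string literals (f-strings included), so the runtime message is identical to the old single-line version. A minimal standalone sketch with hypothetical values chosen so the check fails:

hidden_size = 2048  # hypothetical values, not from the model config
num_heads = 30
tp_size = 1
head_dim = hidden_size // num_heads  # 68, and 68 * 30 != 2048

try:
    if (head_dim * num_heads * tp_size) != hidden_size:
        raise ValueError(f"hidden_size must be divisible by num_heads "
                         f"(got `hidden_size`: {hidden_size}"
                         f" and `num_heads`: {num_heads}).")
except ValueError as err:
    # The adjacent f-strings were fused into one message at compile time:
    print(err)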
@@ -192,7 +194,6 @@ class StableLMEpochModel(nn.Module):
                  config: PretrainedConfig,
                  linear_method: Optional[LinearMethodBase] = None) -> None:
         super().__init__()
-        # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
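The deleted line was dead code: the embedding had already been replaced by VocabParallelEmbedding. As a rough sketch of the idea behind that layer (assumed behavior with made-up sizes, not vLLM's actual implementation), each tensor-parallel rank stores only its slice of the vocabulary:

import torch

# Hypothetical sizes; vLLM derives these from the model config.
vocab_size, hidden_size, tp_size = 32000, 2048, 4
rows_per_rank = vocab_size // tp_size  # 8000 vocabulary rows per rank

# Rank 0 would hold token ids [0, 8000) and the other ranks the rest,
# so no single GPU stores the full 32000 x 2048 table.
rank0_shard = torch.empty(rows_per_rank, hidden_size)
print(rank0_shard.shape)  # torch.Size([8000, 2048])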