Add Model Revision Support (#1014)
Co-authored-by: Jasmond Loh <Jasmond.Loh@hotmail.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
@@ -259,14 +259,15 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "lm_head.weight" in name:
|
||||
# GPT-2 ties the weights of the embedding layer and the final
|
||||
# linear layer.
|
||||
|
||||
Reference in New Issue
Block a user