[Models] Improve iteration over layers (#19497)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from itertools import islice
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -260,7 +261,7 @@ class MPTModel(nn.Module):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
|
||||
for block in self.blocks[self.start_layer:self.end_layer]:
|
||||
for block in islice(self.blocks, self.start_layer, self.end_layer):
|
||||
hidden_states = block(position_ids, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
|
||||
Reference in New Issue
Block a user