[Models] Improve iteration over layers (#26425)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
"""PyTorch MAMBA model."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from itertools import islice
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@@ -162,8 +163,7 @@ class MambaModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states, residual = layer(
|
||||
positions=positions, hidden_states=hidden_states, residual=residual
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user