[Models] Improve iteration over layers (#26425)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
Lukas Geiger
2025-10-08 21:48:33 +01:00
committed by GitHub
parent 4ebc9108a7
commit 93f2c0aa08
8 changed files with 23 additions and 22 deletions

View File

@@ -3,6 +3,7 @@
"""PyTorch MAMBA model."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional
import torch
@@ -162,8 +163,7 @@ class MambaModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions=positions, hidden_states=hidden_states, residual=residual
)