[Models] Improve iteration over layers (#19497)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
Lukas Geiger
2025-08-29 02:26:34 +01:00
committed by GitHub
parent 235c9db8a7
commit de533ab2a1
65 changed files with 129 additions and 83 deletions

View File

@@ -24,6 +24,7 @@
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
@@ -383,7 +384,7 @@ class LlamaModel(nn.Module):
aux_hidden_states = []
for idx, layer in enumerate(
self.layers[self.start_layer:self.end_layer]):
islice(self.layers, self.start_layer, self.end_layer)):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual)