[model][utils] add extract_layer_index utility function (#10599)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-23 22:22:54 -08:00
committed by GitHub
parent eda2b3589c
commit c055747867
6 changed files with 59 additions and 51 deletions

View File

@@ -629,3 +629,24 @@ def maybe_prefix(prefix: str, name: str) -> str:
The string "prefix.name" if prefix was non-empty, otherwise just "name".
"""
return name if not prefix else f"{prefix}.{name}"
def extract_layer_index(layer_name: str) -> int:
"""
Extract the layer index from the module name.
Examples:
- "encoder.layers.0" -> 0
- "encoder.layers.1.self_attn" -> 1
- "2.self_attn" -> 2
- "model.encoder.layers.0.sub.1" -> ValueError
"""
subnames = layer_name.split(".")
int_vals: List[int] = []
for subname in subnames:
try:
int_vals.append(int(subname))
except ValueError:
continue
assert len(int_vals) == 1, (f"layer name {layer_name} should"
" only contain one integer")
return int_vals[0]