[model][utils] add extract_layer_index utility function (#10599)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -629,3 +629,24 @@ def maybe_prefix(prefix: str, name: str) -> str:
|
||||
The string "prefix.name" if prefix was non-empty, otherwise just "name".
|
||||
"""
|
||||
return name if not prefix else f"{prefix}.{name}"
|
||||
|
||||
|
||||
def extract_layer_index(layer_name: str) -> int:
|
||||
"""
|
||||
Extract the layer index from the module name.
|
||||
Examples:
|
||||
- "encoder.layers.0" -> 0
|
||||
- "encoder.layers.1.self_attn" -> 1
|
||||
- "2.self_attn" -> 2
|
||||
- "model.encoder.layers.0.sub.1" -> ValueError
|
||||
"""
|
||||
subnames = layer_name.split(".")
|
||||
int_vals: List[int] = []
|
||||
for subname in subnames:
|
||||
try:
|
||||
int_vals.append(int(subname))
|
||||
except ValueError:
|
||||
continue
|
||||
assert len(int_vals) == 1, (f"layer name {layer_name} should"
|
||||
" only contain one integer")
|
||||
return int_vals[0]
|
||||
|
||||
Reference in New Issue
Block a user