[model][utils] add extract_layer_index utility function (#10599)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2024-11-23 22:22:54 -08:00 (committed by GitHub)
Parent: eda2b3589c
Commit: c055747867

6 changed files with 59 additions and 51 deletions
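
The hunks below show only the call sites in the Gemma2 model; the extract_layer_index helper itself is added in .utils and is not part of this excerpt. As a rough sketch (an assumption for illustration, not the code from the commit), a prefix-based helper of this kind could pull the single integer component out of a dotted module name:

def extract_layer_index(prefix: str) -> int:
    """Return the single integer component of a dotted module prefix.

    For example, "model.layers.3.self_attn" -> 3.
    """
    # Keep only the dot-separated parts that are pure integers.
    int_parts = [int(part) for part in prefix.split(".") if part.isdigit()]
    assert len(int_parts) == 1, (
        f"prefix {prefix!r} should contain exactly one integer component")
    return int_parts[0]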

@@ -42,7 +42,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, extract_layer_index,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -85,7 +86,6 @@ class Gemma2MLP(nn.Module):
 class Gemma2Attention(nn.Module):
 
     def __init__(self,
-                 layer_idx: int,
                  config: Gemma2Config,
                  hidden_size: int,
                  num_heads: int,
@@ -98,7 +98,6 @@ class Gemma2Attention(nn.Module):
                  attn_logits_soft_cap: Optional[float] = None,
                  prefix: str = "") -> None:
         super().__init__()
-        self.layer_idx = layer_idx
         self.config = config
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -145,6 +144,7 @@
         # reference:
         # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
+        layer_idx = extract_layer_index(prefix)
         use_sliding_window = (layer_idx % 2 == 0 and
                               config.interleaved_sliding_window is not None)
         sliding_window = config.interleaved_sliding_window if \
             use_sliding_window else None
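
With the helper in place, the attention module recovers its own index from the prefix it is constructed with, so the parity check above keeps selecting sliding-window attention on even-numbered layers without a layer_idx argument being threaded through. A small illustration, assuming prefixes of the form "model.layers.<i>.self_attn" and the hypothetical sketch above:

interleaved_sliding_window = 4096  # placeholder config value for illustration
layer_idx = extract_layer_index("model.layers.4.self_attn")  # -> 4
use_sliding_window = (layer_idx % 2 == 0
                      and interleaved_sliding_window is not None)  # True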
@@ -178,7 +178,6 @@
     def __init__(
         self,
-        layer_idx: int,
         config: Gemma2Config,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
@@ -187,7 +186,6 @@
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = Gemma2Attention(
-            layer_idx=layer_idx,
             config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -262,11 +260,8 @@
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[-1]),
-                                              config,
-                                              cache_config,
-                                              quant_config,
-                                              prefix=prefix),
+            lambda prefix: Gemma2DecoderLayer(
+                config, cache_config, quant_config, prefix=prefix),
             prefix=f"{prefix}.layers")
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
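
The last hunk drops the inline int(prefix.split(".")[-1]) parsing: make_layers already names each layer "<prefix>.layers.<i>", and a helper that scans the whole prefix works at any nesting depth, whereas taking the last dotted component only works when the prefix happens to end with the index. Continuing with the hypothetical prefixes and sketch from above:

extract_layer_index("model.layers.7")             # -> 7
extract_layer_index("model.layers.7.self_attn")   # -> 7 (still works)
int("model.layers.7.self_attn".split(".")[-1])    # raises ValueError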