[core] gemma2 full context length support (#10584)
Signed-off-by: youkaichao <youkaichao@gmail.com>
@@ -143,12 +143,12 @@ class Gemma2Attention(nn.Module):
             is_neox_style=True,
         )
 
-        # FIXME(woosuk): While Gemma 2 uses sliding window attention for every
-        # odd layer, vLLM currently ignores it and uses global attention for
-        # all layers.
-        use_sliding_window = (layer_idx % 2 == 1
-                              and config.sliding_window is not None)
-        del use_sliding_window  # Unused.
+        # reference:
+        # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
+        use_sliding_window = (layer_idx % 2 == 0 and
+                              config.interleaved_sliding_window is not None)
+        sliding_window = config.interleaved_sliding_window if \
+            use_sliding_window else None
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
@@ -156,6 +156,7 @@ class Gemma2Attention(nn.Module):
                               cache_config=cache_config,
                               quant_config=quant_config,
                               logits_soft_cap=attn_logits_soft_cap,
+                              per_layer_sliding_window=sliding_window,
                               prefix=f"{prefix}.attn")
 
     def forward(
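For context, here is a minimal standalone sketch (not part of the commit) of the per-layer selection the diff introduces: even-indexed layers get Gemma 2's local sliding-window attention, odd-indexed layers stay global. Only the parity check and the interleaved_sliding_window attribute mirror the diff; the helper name, the config stub, and the 4096 example value are assumptions for illustration.

# Hedged sketch, not from the commit: illustrates the even/odd interleaving
# encoded in the diff above. The stub class, helper name, and the 4096
# example value are hypothetical; only the parity rule and the
# `interleaved_sliding_window` name come from the diff.
from typing import Optional


class StubGemma2Config:
    # Stand-in for the HF Gemma 2 config object consumed by vLLM's model code.
    interleaved_sliding_window: Optional[int] = 4096


def per_layer_sliding_window(layer_idx: int,
                             config: StubGemma2Config) -> Optional[int]:
    # Even-indexed layers use local (sliding-window) attention; odd-indexed
    # layers fall back to global attention by returning None.
    use_sliding_window = (layer_idx % 2 == 0
                          and config.interleaved_sliding_window is not None)
    return config.interleaved_sliding_window if use_sliding_window else None


if __name__ == "__main__":
    cfg = StubGemma2Config()
    # Layers 0, 2, 4 get the window size; layers 1, 3, 5 get None (global).
    print([per_layer_sliding_window(i, cfg) for i in range(6)])
    # [4096, None, 4096, None, 4096, None]

Feeding this per-layer value into Attention via per_layer_sliding_window (the second hunk) is what lets the local layers keep their window while the global-attention layers see the full context, which is the point of the #10584 change.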