Refactor sliding window configuration to Transformers best practice (#21927)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -182,21 +182,13 @@ class CohereAttention(nn.Module):
         )
 
         # Model v2 has interleaved sliding windows, v1 does not
-        interleaved_sliding_window = getattr(config,
-                                             "interleaved_sliding_window",
-                                             None)
-        self.v1 = interleaved_sliding_window is None
+        self.v1 = isinstance(config, CohereConfig)
 
-        layer_idx = extract_layer_index(prefix)
-        layer_has_sliding_window = (
-            getattr(config, "sliding_window_pattern", False) and
-            (layer_idx + 1) % self.config.sliding_window_pattern
-            != 0) or (getattr(config, "layer_types", False)
-                      and config.layer_types[layer_idx] == "sliding_attention")
-
-        self.sliding_window = (interleaved_sliding_window
-                               or config.sliding_window
-                               if layer_has_sliding_window else None)
+        self.sliding_window = None
+        if not self.v1:
+            layer_idx = extract_layer_index(prefix)
+            if config.layer_types[layer_idx] == "sliding_attention":
+                self.sliding_window = config.sliding_window
 
         self.attn = Attention(self.num_heads,
                               self.head_dim,
Reference in New Issue
Block a user