[Model] Support Cohere2ForCausalLM (Cohere R7B) (#11203)
@@ -48,7 +48,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (extract_layer_index, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
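Note on the import change: extract_layer_index is the .utils helper that recovers a layer's integer index from its module prefix, which the attention code below uses to decide whether a layer gets a sliding window. A minimal sketch of the idea (assumed behavior, not vLLM's exact implementation):

    def extract_layer_index(prefix: str) -> int:
        # A prefix like "model.layers.3.self_attn" carries exactly one
        # integer component: the layer index.
        indices = [int(part) for part in prefix.split(".") if part.isdigit()]
        assert len(indices) == 1, f"no unique layer index in {prefix!r}"
        return indices[0]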
@@ -171,12 +171,26 @@ class CohereAttention(nn.Module):
             rope_scaling=self.rope_scaling,
             is_neox_style=False,
         )
+
+        sliding_window = getattr(config, "sliding_window", None)
+        # Model v2 has sliding windows, v1 does not
+        self.v1 = sliding_window is None
+
+        layer_idx = extract_layer_index(prefix)
+        layer_has_sliding_window = (
+            getattr(config, "sliding_window_pattern", False)
+            and (layer_idx + 1) % self.config.sliding_window_pattern != 0)
+
+        self.sliding_window = (sliding_window
+                               if layer_has_sliding_window else None)
+
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
                               cache_config=cache_config,
                               quant_config=quant_config,
+                              per_layer_sliding_window=self.sliding_window,
                               prefix=f"{prefix}.attn")
         if self.use_qk_norm:
             self.q_norm = LayerNorm(param_shape=(self.num_heads,
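Note on the layer selection above: Cohere2 interleaves sliding-window and global attention. With a sliding_window_pattern of N, every Nth layer (1-based) gets no window and attends globally; all other layers attend within the window. A runnable sketch of the same arithmetic (the window size 4096 and pattern 4 are assumptions mirroring the published Cohere R7B config, used here only for illustration):

    sliding_window = 4096
    sliding_window_pattern = 4

    def window_for_layer(layer_idx: int):
        # Every pattern-th layer (1-based) is a global-attention layer;
        # the rest use the sliding window, matching the hunk above.
        if (layer_idx + 1) % sliding_window_pattern == 0:
            return None
        return sliding_window

    print([window_for_layer(i) for i in range(8)])
    # [4096, 4096, 4096, None, 4096, 4096, 4096, None]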
@@ -206,7 +220,8 @@ class CohereAttention(nn.Module):
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_qk_norm:
             q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
+        if self.v1 or self.sliding_window:
+            q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
         return output
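Note on the forward-pass change above: v1 models (Command R/R+, no sliding window in the config) apply rotary embeddings on every layer, while v2 applies them only on sliding-window layers, so the global-attention layers run without positional encoding. A condensed sketch of that control flow (hypothetical helper, not vLLM's API):

    def maybe_apply_rope(rotary_emb, positions, q, k, v1, sliding_window):
        # v1: always rotate q/k. v2: rotate only when this layer has a
        # sliding window; on global layers q/k pass through unrotated.
        if v1 or sliding_window:
            q, k = rotary_emb(positions, q, k)
        return q, k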