Remove unused kwargs from model definitions (#13555)
This commit is contained in:
@@ -21,14 +21,14 @@
|
||||
|
||||
# This file is based on the LLama model definition file in transformers
|
||||
"""PyTorch Cohere model."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from transformers import CohereConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@@ -218,8 +218,6 @@ class CohereAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
@@ -227,7 +225,7 @@ class CohereAttention(nn.Module):
|
||||
q, k = self._apply_qk_norm(q, k)
|
||||
if self.v1 or self.sliding_window:
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -255,8 +253,6 @@ class CohereDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@@ -265,8 +261,6 @@ class CohereDecoderLayer(nn.Module):
|
||||
hidden_states_attention = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states_mlp = self.mlp(hidden_states)
|
||||
# Add everything together
|
||||
@@ -311,8 +305,6 @@ class CohereModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@@ -326,13 +318,10 @@ class CohereModel(nn.Module):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
@@ -389,13 +378,10 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
||||
Reference in New Issue
Block a user