Remove unused kwargs from model definitions (#13555)
This commit is contained in:
@@ -20,14 +20,14 @@
|
||||
"""PyTorch Falcon model."""
|
||||
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
from transformers import FalconConfig as HF_FalconConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
@@ -190,8 +190,6 @@ class FalconAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, bias = self.query_key_value(hidden_states)
|
||||
if bias is not None:
|
||||
@@ -199,7 +197,7 @@ class FalconAttention(nn.Module):
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
if self.use_rotary:
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
attn_output, bias = self.dense(attn_output)
|
||||
return attn_output, bias
|
||||
|
||||
@@ -291,8 +289,6 @@ class FalconDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
|
||||
@@ -306,8 +302,6 @@ class FalconDecoderLayer(nn.Module):
|
||||
attention_output, attention_bias = self.self_attention(
|
||||
positions=positions,
|
||||
hidden_states=attention_layernorm_out,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
if self.reduce_row_parallel_results and attention_bias is not None:
|
||||
attention_output += attention_bias
|
||||
@@ -384,8 +378,6 @@ class FalconModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@@ -396,14 +388,8 @@ class FalconModel(nn.Module):
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
else:
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
for layer in self.h[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
@@ -450,14 +436,11 @@ class FalconForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.transformer(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
|
||||
Reference in New Issue
Block a user