Remove unused kwargs from model definitions (#13555)

This commit is contained in:
Harry Mellor
2025-02-25 01:13:52 +00:00
committed by GitHub
parent f61528d46d
commit cdc1fa12eb
104 changed files with 436 additions and 1654 deletions

View File

@@ -1,15 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Set, Tuple
from typing import Iterable, Optional, Set, Tuple
import torch
from torch import nn
from transformers import BertConfig
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
@@ -113,12 +114,9 @@ class BertEncoder(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
for i in range(len(self.layer)):
layer = self.layer[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
for layer in self.layer:
hidden_states = layer(hidden_states)
return hidden_states
@@ -152,13 +150,8 @@ class BertLayer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.output")
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata,
):
attn_output = self.attention(hidden_states, kv_cache, attn_metadata)
def forward(self, hidden_states: torch.Tensor):
attn_output = self.attention(hidden_states)
intermediate_output = self.intermediate(attn_output)
output = self.output(intermediate_output, attn_output)
return output
@@ -191,10 +184,8 @@ class BertAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
self_output = self.self(hidden_states, kv_cache, attn_metadata)
self_output = self.self(hidden_states)
return self.output(self_output, hidden_states)
@@ -246,12 +237,10 @@ class BertSelfAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
output = self.attn(q, k, v, kv_cache, attn_metadata)
output = self.attn(q, k, v)
return output
@@ -343,8 +332,6 @@ class BertModel(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
@@ -352,13 +339,14 @@ class BertModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
attn_metadata = get_forward_context().attn_metadata
assert hasattr(attn_metadata, "seq_lens_tensor")
hidden_states = self.embeddings(
input_ids=input_ids,
seq_lens=attn_metadata.seq_lens_tensor,
position_ids=position_ids,
token_type_ids=token_type_ids)
return self.encoder(hidden_states, kv_caches, attn_metadata)
return self.encoder(hidden_states)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
@@ -420,17 +408,13 @@ class BertEmbeddingModel(nn.Module):
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model(input_ids=input_ids,
position_ids=positions,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
attn_metadata=attn_metadata)
intermediate_tensors=intermediate_tensors)
def pooler(
self,
@@ -519,16 +503,12 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.bert(input_ids=input_ids,
position_ids=positions,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
attn_metadata=attn_metadata,
token_type_ids=token_type_ids)