Remove unused kwargs from model definitions (#13555)

Harry Mellor authored 2025-02-25 01:13:52 +00:00, committed by GitHub
parent f61528d46d, commit cdc1fa12eb
104 changed files with 436 additions and 1654 deletions


@@ -1,11 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union
 
 import torch
 import torch.nn as nn
 
-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -230,15 +230,13 @@ class DbrxAttention(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.Wqkv(hidden_states)
         if self.clip_qkv is not None:
             qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         hidden_states, _ = self.out_proj(attn_output)
         return hidden_states
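After this change the per-layer call passes only q, k and v; the KV cache and attention metadata are no longer threaded through every model forward, which keeps model code free of engine-specific plumbing. The sketch below only illustrates that style, assuming the attention layer resolves its per-step state from a context object installed by the runner; StepContext, set_step_context and _STEP_CONTEXT are hypothetical names, not vLLM APIs.

# Hypothetical illustration: attention reads per-step state from a shared
# context instead of taking kv_cache/attn_metadata as forward() arguments.
# StepContext, set_step_context and _STEP_CONTEXT are made-up names.
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn


@dataclass
class StepContext:
    kv_cache: torch.Tensor  # one cache tensor per layer in a real engine
    # block tables, sequence lengths, etc. would also live here


_STEP_CONTEXT: Optional[StepContext] = None


def set_step_context(ctx: StepContext) -> None:
    """Installed once per step by the runner, before any layer runs."""
    global _STEP_CONTEXT
    _STEP_CONTEXT = ctx


class ContextualAttention(nn.Module):
    def forward(self, q: torch.Tensor, k: torch.Tensor,
                v: torch.Tensor) -> torch.Tensor:
        ctx = _STEP_CONTEXT
        assert ctx is not None, "runner must call set_step_context() first"
        # A real implementation would write k/v into ctx.kv_cache and attend
        # against it; here we just return a tensor of the right shape.
        return torch.zeros_like(q)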
@@ -265,16 +263,12 @@ class DbrxFusedNormAttention(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.norm_1(hidden_states)
         x = self.attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         hidden_states = residual + x
         residual = hidden_states
@@ -303,14 +297,10 @@ class DbrxBlock(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states, residual = self.norm_attn_norm(
             position_ids=position_ids,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         hidden_states = self.ffn(hidden_states)
         hidden_states = hidden_states + residual
@@ -353,8 +343,6 @@ class DbrxModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -366,14 +354,8 @@ class DbrxModel(nn.Module):
         else:
             assert intermediate_tensors
             hidden_states = intermediate_tensors["hidden_states"]
-        for i in range(self.start_layer, self.end_layer):
-            block = self.blocks[i]
-            hidden_states = block(
-                position_ids,
-                hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
-            )
+        for block in self.blocks[self.start_layer:self.end_layer]:
+            hidden_states = block(position_ids, hidden_states)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.norm_f(hidden_states)
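With the per-layer cache argument gone there is nothing left to index by layer, so the index-based loop collapses into a plain slice over the block list. Slicing an nn.ModuleList yields the same module objects in the same order, as this small standalone check shows:

# Slicing an nn.ModuleList returns the same modules in order, so iterating
# the slice is equivalent to indexing range(start_layer, end_layer).
import torch.nn as nn

blocks = nn.ModuleList(nn.Linear(4, 4) for _ in range(6))
start_layer, end_layer = 2, 5

by_index = [blocks[i] for i in range(start_layer, end_layer)]
by_slice = list(blocks[start_layer:end_layer])
assert len(by_index) == len(by_slice) == 3
assert all(a is b for a, b in zip(by_index, by_slice))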
@@ -415,14 +397,11 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, intermediate_tensors,
-                                         inputs_embeds)
+        hidden_states = self.transformer(input_ids, positions,
+                                         intermediate_tensors, inputs_embeds)
         return hidden_states
 
     def compute_logits(
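At the top level the same simplification shows up in the caller: only token ids and positions (plus the optional pipeline-parallel inputs) are passed, with no cache or metadata arguments. Below is a self-contained toy stand-in for that calling convention, not the real DbrxForCausalLM or vLLM runner code.

# Toy stand-in showing the simplified top-level signature; not vLLM code.
from typing import Optional

import torch
import torch.nn as nn


class ToyCausalLM(nn.Module):
    def __init__(self, vocab_size: int = 32, hidden_size: int = 8) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[dict] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # A real model would run its decoder blocks here.
        if inputs_embeds is not None:
            return inputs_embeds
        return self.embed(input_ids)


model = ToyCausalLM()
input_ids = torch.tensor([1, 2, 3])
positions = torch.arange(input_ids.numel())
hidden_states = model(input_ids, positions)  # no kv_caches / attn_metadata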