Remove unused kwargs from model definitions (#13555)

This commit is contained in:
Harry Mellor
2025-02-25 01:13:52 +00:00
committed by GitHub
parent f61528d46d
commit cdc1fa12eb
104 changed files with 436 additions and 1654 deletions

View File

@@ -20,13 +20,13 @@
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union
from typing import Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import StableLmConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
@@ -147,13 +147,11 @@ class StablelmAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
@@ -183,8 +181,6 @@ class StablelmDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
@@ -192,8 +188,6 @@ class StablelmDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states
@@ -241,8 +235,6 @@ class StableLMEpochModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
@@ -254,14 +246,8 @@ class StableLMEpochModel(nn.Module):
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(positions, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm(hidden_states)
@@ -296,13 +282,10 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states