Remove unused kwargs from model definitions (#13555)

commit cdc1fa12eb (parent f61528d46d)
Author: Harry Mellor
Date: 2025-02-25 01:13:52 +00:00
Committed by: GitHub

104 changed files with 436 additions and 1654 deletions
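The change is mechanical and repeated across all 104 files: model forward() methods drop their kv_caches and attn_metadata parameters, because vLLM's attention layers now resolve that per-step state internally (via its forward-context machinery) instead of having every decoder layer thread it through. The jamba.py diff below is representative. As a minimal, self-contained sketch of the pattern (toy names throughout; this is not vLLM's actual ForwardContext implementation):

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, Optional

import torch


@dataclass
class ForwardContext:
    attn_metadata: object             # batch layout for the attention kernels
    kv_cache: Optional[torch.Tensor]  # paged KV storage, if any


_CURRENT: Optional[ForwardContext] = None


@contextmanager
def set_forward_context(ctx: ForwardContext) -> Iterator[None]:
    # Install the per-step context for the duration of one model call.
    global _CURRENT
    prev, _CURRENT = _CURRENT, ctx
    try:
        yield
    finally:
        _CURRENT = prev


class ToyAttention(torch.nn.Module):
    # Note the signature: no kv_cache / attn_metadata parameters, matching
    # the bare self.attn(q, k, v) call sites in the diff below.
    def forward(self, q, k, v):
        assert _CURRENT is not None, "must run inside set_forward_context"
        # A real backend would dispatch to a paged-attention kernel using
        # _CURRENT.attn_metadata and _CURRENT.kv_cache; plain SDPA stands in.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)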

vllm/model_executor/models/jamba.py

@@ -1,12 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 """Inference-only Jamba model."""
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Iterable, Optional, Set, Tuple
 
 import torch
 from torch import nn
 from transformers import JambaConfig
 
-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -138,7 +137,6 @@ class JambaMambaDecoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
         **kwargs,
@@ -150,8 +148,7 @@
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
 
-        hidden_states = self.mamba(hidden_states, attn_metadata,
-                                   mamba_cache_params)
+        hidden_states = self.mamba(hidden_states, mamba_cache_params)
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
             hidden_states, residual)
@@ -223,13 +220,11 @@ class JambaAttentionDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         **kwargs,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output
@@ -237,8 +232,6 @@
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         **kwargs,
     ):
@@ -252,8 +245,6 @@
         hidden_states = self.self_attention(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
@@ -320,8 +311,6 @@ class JambaModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
@@ -339,12 +328,9 @@
-        kv_cache_index = 0
         mamba_cache_index = 0
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            kv_cache = None
+        for layer in self.layers[self.start_layer:self.end_layer]:
             layer_mamba_cache_params = None
-            if isinstance(layer, JambaAttentionDecoderLayer):
-                kv_cache = kv_caches[kv_cache_index]
-                kv_cache_index += 1
             if isinstance(layer, JambaMambaDecoderLayer):
                 current_state_layer = mamba_cache_index
@@ -355,8 +341,6 @@
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
-                kv_cache=kv_cache,
-                attn_metadata=attn_metadata,
                 residual=residual,
                 mamba_cache_params=layer_mamba_cache_params)
         if not get_pp_group().is_last_rank:
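Read together, the two hunks above simplify JambaModel's layer loop: attention layers no longer need any per-layer state threaded in, and only Mamba layers still index into their cache. Consolidated, the loop now reads roughly as follows (at_layer_idx is the MambaCacheParams helper used in surrounding context the hunks don't show):

        mamba_cache_index = 0
        for layer in self.layers[self.start_layer:self.end_layer]:
            # Only Mamba layers carry explicit per-layer state now; attention
            # layers resolve their KV cache and metadata internally.
            layer_mamba_cache_params = None
            if isinstance(layer, JambaMambaDecoderLayer):
                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                    mamba_cache_index)
                mamba_cache_index += 1
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
                mamba_cache_params=layer_mamba_cache_params)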
@@ -429,8 +413,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
-                kv_caches: List[KVCache],
-                attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
@@ -443,8 +425,7 @@
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, mamba_cache_params,
+        hidden_states = self.model(input_ids, positions, mamba_cache_params,
                                    intermediate_tensors, inputs_embeds)
         return hidden_states
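Tying back to the sketch at the top: after this commit a runner installs the per-step context once, and the model call itself carries only token-level inputs. A hypothetical call site using the toy set_forward_context from that sketch (the real runner wiring differs):

# attn_metadata, input_ids, positions, and model are assumed to be in scope.
with set_forward_context(ForwardContext(attn_metadata, kv_cache=None)):
    hidden_states = model(
        input_ids=input_ids,
        positions=positions,
        intermediate_tensors=None,
        inputs_embeds=None,
    )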