Remove unused kwargs from model definitions (#13555)
@@ -1,17 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 """Inference-only Bamba model."""
 # Added by the IBM Team, 2024
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Iterable, Optional, Set, Tuple

 import torch
 from torch import nn
 from transformers import BambaConfig

-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
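The pattern running through this diff: rather than threading `attn_metadata` (and kv-cache tensors) through every `forward` signature, layers now read the per-step state from vLLM's forward context, hence the new `get_forward_context` import. A minimal sketch of the idea, where `_SimpleContext` and `_CURRENT` are illustrative stand-ins rather than the real `vllm.forward_context` internals:

    # Sketch only: stand-in for the forward-context mechanism, not vLLM code.
    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class _SimpleContext:
        attn_metadata: Optional[Any] = None  # populated once per engine step

    _CURRENT = _SimpleContext()

    def get_forward_context() -> _SimpleContext:
        # Layers call this instead of receiving attn_metadata as a parameter.
        return _CURRENT

With something like this in place, the engine sets the context before calling `model.forward()`, and model signatures no longer have to change when the engine's bookkeeping does.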
@@ -107,7 +107,6 @@ class BambaMixerDecoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
         sequence_idx: Optional[torch.Tensor] = None,
@@ -120,8 +119,8 @@ class BambaMixerDecoderLayer(nn.Module):
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)

-        hidden_states = self.mamba(hidden_states, attn_metadata,
-                                   mamba_cache_params, sequence_idx)
+        hidden_states = self.mamba(hidden_states, mamba_cache_params,
+                                   sequence_idx)
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
             hidden_states, residual)
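The Mamba mixer call drops `attn_metadata` too; presumably `self.mamba`, like the attention layer below, now pulls whatever scheduling state it needs from the forward context instead of from an explicit argument.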
@@ -215,15 +214,13 @@ class BambaAttentionDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         **kwargs,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

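`self.attn` is vLLM's `Attention` layer, now invoked with just `q, k, v`. A rough sketch of how such a layer can shed the explicit `kv_cache`/`attn_metadata` arguments, reusing the hypothetical `get_forward_context` stand-in from the earlier sketch (the real layer also resolves its own paged kv cache internally):

    # Toy attention layer; not the vLLM Attention implementation.
    import torch
    from torch import nn

    class ContextAttention(nn.Module):
        def forward(self, q: torch.Tensor, k: torch.Tensor,
                    v: torch.Tensor) -> torch.Tensor:
            # Was: forward(q, k, v, kv_cache, attn_metadata); both now come
            # from context / layer-owned state rather than from the caller.
            attn_metadata = get_forward_context().attn_metadata
            del attn_metadata  # a real kernel would consult it; this toy won't
            scale = q.shape[-1] ** -0.5
            return torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v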
@@ -231,8 +228,6 @@ class BambaAttentionDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         **kwargs,
     ):
@@ -246,8 +241,6 @@ class BambaAttentionDecoderLayer(nn.Module):
         hidden_states = self.self_attention(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
@@ -312,8 +305,6 @@ class BambaModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
@@ -323,6 +314,7 @@ class BambaModel(nn.Module):
         # proper continuous batching computation including
         # chunked prefill
         seq_idx = None
+        attn_metadata = get_forward_context().attn_metadata
         if attn_metadata.num_prefills > 0:
             seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
             for i, (srt, end) in enumerate(
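This hunk is where the model itself now recovers `attn_metadata`: once per forward pass, from the forward context, rather than from a parameter. The `seq_idx` loop is cut off in this view; purely as an illustration (the ranges and names below are assumed, not taken from the source), a per-token sequence index for chunked prefill can be filled like this:

    # Hypothetical illustration of building seq_idx; not the original loop.
    import torch

    token_ranges = [(0, 4), (4, 9)]   # assumed (srt, end) per prefill request
    seq_idx = torch.zeros(9, dtype=torch.int32)
    for i, (srt, end) in enumerate(token_ranges):
        seq_idx[srt:end] = i          # tokens of request i share index i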
@@ -348,9 +340,7 @@ class BambaModel(nn.Module):
         num_attn = 0
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            kv_cache = None
             if isinstance(layer, BambaAttentionDecoderLayer):
-                kv_cache = kv_caches[num_attn]
                 num_attn += 1

             layer_mamba_cache_params = None
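With attention layers owning their cache lookup, the layer walk no longer fetches `kv_caches[num_attn]`. Note that `num_attn` is still incremented; judging from the surrounding context, it presumably still offsets the Mamba cache index for the non-attention layers.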
@@ -361,8 +351,6 @@ class BambaModel(nn.Module):
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
-                kv_cache=kv_cache,
-                attn_metadata=attn_metadata,
                 residual=residual,
                 mamba_cache_params=layer_mamba_cache_params,
                 sequence_idx=seq_idx,
@@ -440,8 +428,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
-                kv_caches: List[KVCache],
-                attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
@@ -454,8 +440,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
             *self._get_mamba_cache_shape())
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, mamba_cache_params,
+        hidden_states = self.model(input_ids, positions, mamba_cache_params,
                                    intermediate_tensors, inputs_embeds)

         return hidden_states
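Net effect: `kv_caches` and `attn_metadata` vanish from `BambaForCausalLM.forward`, `BambaModel.forward`, and both decoder-layer `forward`s, and the `List` and `AttentionMetadata` imports they required go with them. `**kwargs` survives because the Mamba cache still reads `current_run_tensors(**kwargs)`.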