Remove unused kwargs from model definitions (#13555)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""PyTorch MAMBA2 model."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
from typing import Iterable, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
|
||||
@@ -63,7 +64,6 @@ class Mamba2DecoderLayer(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
sequence_idx: Optional[torch.Tensor],
|
||||
@@ -75,8 +75,8 @@ class Mamba2DecoderLayer(nn.Module):
|
||||
else:
|
||||
hidden_states, residual = self.norm(hidden_states, residual)
|
||||
|
||||
hidden_states = self.mixer(hidden_states, attn_metadata,
|
||||
mamba_cache_params, sequence_idx)
|
||||
hidden_states = self.mixer(hidden_states, mamba_cache_params,
|
||||
sequence_idx)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@@ -122,7 +122,6 @@ class Mamba2Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
@@ -142,6 +141,7 @@ class Mamba2Model(nn.Module):
|
||||
# proper continuous batching computation including
|
||||
# chunked prefill
|
||||
seq_idx = None
|
||||
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
|
||||
if attn_metadata.num_prefills > 0:
|
||||
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
|
||||
for i, (srt, end) in enumerate(
|
||||
@@ -158,7 +158,6 @@ class Mamba2Model(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
mamba_cache_params=mamba_cache_params.at_layer_idx(
|
||||
i - self.start_layer),
|
||||
@@ -224,8 +223,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs):
|
||||
@@ -238,9 +235,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
hidden_states = self.backbone(input_ids, positions, attn_metadata,
|
||||
mamba_cache_params, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
Reference in New Issue
Block a user