Remove unused kwargs from model definitions (#13555)

Author: Harry Mellor
Date: 2025-02-25 01:13:52 +00:00
Committed by: GitHub
Parent: f61528d46d
Commit: cdc1fa12eb

104 changed files with 436 additions and 1654 deletions


@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """PyTorch MAMBA2 model."""
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Iterable, Optional, Set, Tuple
 
 import torch
 from torch import nn
@@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
@@ -63,7 +64,6 @@ class Mamba2DecoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
         sequence_idx: Optional[torch.Tensor],
@@ -75,8 +75,8 @@
         else:
             hidden_states, residual = self.norm(hidden_states, residual)
 
-        hidden_states = self.mixer(hidden_states, attn_metadata,
-                                   mamba_cache_params, sequence_idx)
+        hidden_states = self.mixer(hidden_states, mamba_cache_params,
+                                   sequence_idx)
         return hidden_states, residual
@@ -122,7 +122,6 @@ class Mamba2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
@@ -142,6 +141,7 @@
         # proper continuous batching computation including
         # chunked prefill
         seq_idx = None
+        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
         if attn_metadata.num_prefills > 0:
             seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
             for i, (srt, end) in enumerate(
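The hunk above builds seq_idx during prefill so that every flattened token is tagged with the index of the sequence it belongs to, which the comment notes is needed for continuous batching with chunked prefill. A standalone illustration of that loop, using a made-up query_start_loc (the boundary values here are illustrative, not taken from the commit):

import torch

# Made-up prefill boundaries: sequence 0 owns tokens 0..2, sequence 1 owns tokens 3..6.
query_start_loc = torch.tensor([0, 3, 7])
input_ids = torch.arange(7)  # 7 flattened prefill tokens

# Same construction as in the diff: label each token with its sequence index.
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
for i, (srt, end) in enumerate(zip(query_start_loc, query_start_loc[1:])):
    seq_idx[srt:end] = i

print(seq_idx.tolist())  # [0, 0, 0, 1, 1, 1, 1]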
@@ -158,7 +158,6 @@
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
-                attn_metadata=attn_metadata,
                 residual=residual,
                 mamba_cache_params=mamba_cache_params.at_layer_idx(
                     i - self.start_layer),
@@ -224,8 +223,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
-                kv_caches: List[KVCache],
-                attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
@@ -238,9 +235,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
-        hidden_states = self.backbone(input_ids, positions, attn_metadata,
-                                      mamba_cache_params, intermediate_tensors,
-                                      inputs_embeds)
+        hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
+                                      intermediate_tensors, inputs_embeds)
         return hidden_states
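
Across every hunk above, the commit drops the explicitly passed attn_metadata (and kv_caches) arguments and instead has the model fetch the metadata from the per-step forward context via vllm.forward_context.get_forward_context(). The following is a minimal, self-contained sketch of that pattern; AttnMetadata, ForwardContext, set_forward_context, and ToyDecoderLayer are stand-in names for illustration, not the real vLLM classes.

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, List, Optional


@dataclass
class AttnMetadata:
    """Stand-in for the metadata object a model runner builds each step."""
    num_prefills: int


@dataclass
class ForwardContext:
    attn_metadata: AttnMetadata


_forward_context: Optional[ForwardContext] = None


@contextmanager
def set_forward_context(metadata: AttnMetadata) -> Iterator[None]:
    """Hypothetical helper: install the metadata for the duration of one step."""
    global _forward_context
    _forward_context = ForwardContext(attn_metadata=metadata)
    try:
        yield
    finally:
        _forward_context = None


def get_forward_context() -> ForwardContext:
    """Mimics fetching the ambient context, as the layers now do in the diff."""
    assert _forward_context is not None, "must be called inside a model step"
    return _forward_context


class ToyDecoderLayer:
    # After the change: attn_metadata no longer appears in the signature; the
    # layer pulls it from the forward context only where it is actually needed.
    def forward(self, hidden_states: List[float]) -> List[float]:
        attn_metadata = get_forward_context().attn_metadata
        scale = 2.0 if attn_metadata.num_prefills > 0 else 1.0
        return [scale * h for h in hidden_states]


if __name__ == "__main__":
    layer = ToyDecoderLayer()
    with set_forward_context(AttnMetadata(num_prefills=1)):
        print(layer.forward([1.0, 2.0, 3.0]))  # -> [2.0, 4.0, 6.0]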