Remove unused kwargs from model definitions (#13555)
This commit is contained in:
@@ -16,7 +16,7 @@ from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin,
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
@@ -460,15 +460,13 @@ class MolmoAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
if self.q_norm is not None and self.k_norm is not None:
|
||||
q, k = self._apply_qk_norm(q, k)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -580,8 +578,6 @@ class MolmoDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
# Self Attention
|
||||
@@ -594,8 +590,6 @@ class MolmoDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
@@ -610,8 +604,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
# Self Attention
|
||||
@@ -619,8 +611,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
@@ -841,8 +831,6 @@ class MolmoModel(nn.Module, SupportsQuant):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -858,13 +846,10 @@ class MolmoModel(nn.Module, SupportsQuant):
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
# Apply blocks one-by-one.
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
@@ -1643,8 +1628,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.LongTensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
@@ -1663,8 +1646,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
|
||||
hidden_states = self.model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user