Remove unused kwargs from model definitions (#13555)

Author: Harry Mellor
Date: 2025-02-25 01:13:52 +00:00 (committed by GitHub)
Commit: cdc1fa12eb (parent: f61528d46d)
104 changed files with 436 additions and 1654 deletions
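The removed parameters were dead weight at the model level: vLLM's Attention layer now resolves the KV cache and attention metadata internally, from a per-step forward context that the model runner installs before calling the model, rather than taking them as forward() arguments. Below is a minimal, self-contained sketch of that pattern; all names are illustrative, not vLLM's actual API.

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Iterator, Optional

@dataclass
class ForwardContext:
    # Per-step state that used to be threaded through every forward().
    attn_metadata: Any
    kv_cache: Any

_current_ctx: Optional[ForwardContext] = None

@contextmanager
def set_forward_context(ctx: ForwardContext) -> Iterator[None]:
    # The runner installs the state for the duration of one model call.
    global _current_ctx
    prev, _current_ctx = _current_ctx, ctx
    try:
        yield
    finally:
        _current_ctx = prev

def get_forward_context() -> ForwardContext:
    assert _current_ctx is not None, "runner must install the context first"
    return _current_ctx

class Attention:
    def forward(self, q, k, v):
        # The layer pulls state from the ambient context, so callers no
        # longer pass kv_cache/attn_metadata explicitly.
        ctx = get_forward_context()
        _ = (ctx.kv_cache, ctx.attn_metadata)  # consumed by the real kernel
        return q  # stand-in for the actual attention computation

With that in place, every kv_cache/attn_metadata parameter in the diff below can simply be deleted.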

vllm/model_executor/models/olmo2.py

@@ -24,12 +24,12 @@
 """Inference-only OLMo2 model compatible with HuggingFace weights."""
 
 from functools import partial
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, Optional, Tuple, Union
 
 import torch
 from torch import nn
 
-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.distributed.communication_op import tensor_model_parallel_all_gather
@@ -153,14 +153,12 @@ class Olmo2Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output
@@ -239,13 +237,10 @@ class Olmo2DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         # Attention block.
         residual = hidden_states
-        hidden_states = self.self_attn(positions, hidden_states, kv_cache,
-                                       attn_metadata)
+        hidden_states = self.self_attn(positions, hidden_states)
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = hidden_states + residual
@@ -287,8 +282,6 @@ class Olmo2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
     ) -> Union[torch.Tensor, IntermediateTensors]:
         """
@@ -307,14 +300,9 @@
         assert isinstance(hidden_states, torch.Tensor)
 
         # Apply blocks one-by-one.
-        for i in range(self.start_layer, self.end_layer):
+        for layer in self.layers[self.start_layer:self.end_layer]:
             # shape: (batch_size, seq_len, d_model)
-            hidden_states = self.layers[i](
-                positions,
-                hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
-            )
+            hidden_states = layer(positions, hidden_states)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
@@ -357,15 +345,11 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,
-            kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
             intermediate_tensors=intermediate_tensors,
         )
         return hidden_states
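
Caller-side, the same sketch shows why every forward() above shrinks: the runner owns the per-step state and installs it once, outside the model. Again illustrative, not vLLM's actual runner code:

# Toy stand-ins; a real runner passes tensors and real metadata here.
ctx = ForwardContext(attn_metadata="metadata-for-this-step", kv_cache="cache")
with set_forward_context(ctx):
    out = Attention().forward(q="q", k="k", v="v")
assert out == "q"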