Remove unused kwargs from model definitions (#13555)
This commit is contained in:
@@ -18,13 +18,13 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only BLOOM model compatible with HuggingFace weights."""
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import BloomConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
@@ -126,13 +126,11 @@ class BloomAttention(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
del position_ids # Unused.
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.dense(attn_output)
|
||||
return output
|
||||
|
||||
@@ -193,8 +191,6 @@ class BloomBlock(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
# Layer norm at the beginning of the transformer layer.
|
||||
layernorm_output = self.input_layernorm(hidden_states)
|
||||
@@ -209,8 +205,6 @@ class BloomBlock(nn.Module):
|
||||
attention_output = self.self_attention(
|
||||
position_ids=position_ids,
|
||||
hidden_states=layernorm_output,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
attention_output = attention_output + residual
|
||||
layernorm_output = self.post_attention_layernorm(attention_output)
|
||||
@@ -266,8 +260,6 @@ class BloomModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@@ -279,14 +271,8 @@ class BloomModel(nn.Module):
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
position_ids,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
for layer in self.h[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(position_ids, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
@@ -322,14 +308,11 @@ class BloomForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.transformer(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
|
||||
Reference in New Issue
Block a user