[Core] Refactor Attention Take 2 (#3462)

commit 925f3332ca
parent b0dfa91dd7
Author: Woosuk Kwon
Date: 2024-03-24 21:39:33 -07:00
Committed by: GitHub

47 changed files with 1268 additions and 1117 deletions
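In this file the refactor amounts to three coordinated changes: `Attention` moves from `vllm.model_executor.layers.attention` to a new top-level `vllm.attention` package that also exports `AttentionMetadata` (replacing `InputMetadata`), and each layer's KV cache shrinks from a `(key_cache, value_cache)` tuple to a single opaque `torch.Tensor`. A minimal sketch of the new-style layer setup follows; the `(num_heads, head_size, scale)` constructor arguments follow vLLM's usual pattern but are an assumption here, not part of this diff:

import torch
from vllm.attention import Attention, AttentionMetadata  # new home

# Hypothetical head geometry, for illustration only.
num_heads, head_dim = 12, 64
attn = Attention(num_heads, head_dim, scale=head_dim**-0.5)

def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                      kv_cache: torch.Tensor,
                      attn_metadata: AttentionMetadata) -> torch.Tensor:
    # One cache tensor per layer; its layout belongs to the attention
    # backend, so model code never unpacks it.
    return attn(q, k, v, kv_cache, attn_metadata)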

vllm/model_executor/models/opt.py

@@ -17,15 +17,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only OPT model compatible with HuggingFace weights."""
-from typing import List, Optional, Tuple
+from typing import List, Optional
 
 import torch
 from torch import nn
 from transformers import OPTConfig
 
-from vllm.model_executor.input_metadata import InputMetadata
+from vllm.attention import Attention, AttentionMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
@@ -42,8 +41,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
 
-KVCache = Tuple[torch.Tensor, torch.Tensor]
-
 
 class OPTLearnedPositionalEmbedding(nn.Embedding):
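With the module-level `KVCache = Tuple[torch.Tensor, torch.Tensor]` alias deleted, model code no longer knows or cares how keys and values are laid out. As a hedged illustration of what a caller might allocate per layer now, with every size below hypothetical and the real layout chosen by the attention backend rather than by anything in this diff:

import torch
from typing import List

# All sizes are hypothetical; the actual cache shape is backend-specific.
num_layers, num_blocks, block_size = 24, 512, 16
num_kv_heads, head_dim = 12, 64

# One opaque tensor per layer replaces the old (key_cache, value_cache)
# tuple; here keys and values simply share a leading dimension of 2.
kv_caches: List[torch.Tensor] = [
    torch.empty(2, num_blocks, block_size, num_kv_heads, head_dim)
    for _ in range(num_layers)
]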
@@ -97,14 +94,12 @@ class OPTAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        kv_cache: KVCache,
-        input_metadata: InputMetadata,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
-        key_cache, value_cache = kv_cache
-        attn_output = self.attn(q, k, v, key_cache, value_cache,
-                                input_metadata)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.out_proj(attn_output)
         return output
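This is the one hunk that changes more than names: the tuple unpack disappears. The same before/after, restated as standalone Python for clarity (the `attn` parameter stands in for `self.attn`, and the old-style types appear only for contrast):

import torch
from typing import Any, Tuple
from vllm.attention import Attention, AttentionMetadata

def old_call(attn: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
             kv_cache: Tuple[torch.Tensor, torch.Tensor],
             input_metadata: Any) -> torch.Tensor:
    # Pre-refactor: model code unpacked the cache pair itself.
    key_cache, value_cache = kv_cache
    return attn(q, k, v, key_cache, value_cache, input_metadata)

def new_call(attn: Attention, q: torch.Tensor, k: torch.Tensor,
             v: torch.Tensor, kv_cache: torch.Tensor,
             attn_metadata: AttentionMetadata) -> torch.Tensor:
    # Post-refactor: the cache tensor passes through untouched.
    return attn(q, k, v, kv_cache, attn_metadata)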
@@ -152,8 +147,8 @@ class OPTDecoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        kv_cache: KVCache,
-        input_metadata: InputMetadata,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         # Self Attention
         residual = hidden_states
@@ -162,7 +157,7 @@ class OPTDecoderLayer(nn.Module):
         hidden_states = self.self_attn_layer_norm(hidden_states)
         hidden_states = self.self_attn(hidden_states=hidden_states,
                                        kv_cache=kv_cache,
-                                       input_metadata=input_metadata)
+                                       attn_metadata=attn_metadata)
         hidden_states = residual + hidden_states
         # 350m applies layer norm AFTER attention
         if not self.do_layer_norm_before:
@@ -241,8 +236,8 @@ class OPTDecoder(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         inputs_embeds = self.embed_tokens(input_ids)
         pos_embeds = self.embed_positions(positions)
@@ -252,7 +247,7 @@
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
+            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
 
         if self.final_layer_norm is not None:
             hidden_states = self.final_layer_norm(hidden_states)
@@ -275,10 +270,10 @@
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        return self.decoder(input_ids, positions, kv_caches, input_metadata)
+        return self.decoder(input_ids, positions, kv_caches, attn_metadata)
 
 
 class OPTForCausalLM(nn.Module):
@@ -300,11 +295,11 @@ class OPTForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   input_metadata)
+                                   attn_metadata)
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
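Taken together, the hunks above give `OPTAttention`, `OPTDecoderLayer`, `OPTDecoder`, `OPTModel`, and `OPTForCausalLM` one shared calling convention. A structural sketch of that shared contract; the `Protocol` below is illustrative, not a class introduced by this commit:

from typing import List, Protocol

import torch
from vllm.attention import AttentionMetadata

class VllmDecoderForward(Protocol):
    # Illustrative only: the uniform forward signature that the model
    # classes in this file converge on after the refactor.
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        ...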