[Core] Refactor Attention Take 2 (#3462)

This commit is contained in:
Woosuk Kwon
2024-03-24 21:39:33 -07:00
committed by GitHub
parent b0dfa91dd7
commit 925f3332ca
47 changed files with 1268 additions and 1117 deletions

View File

@@ -18,15 +18,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
from typing import List, Optional
import torch
from torch import nn
from transformers import GPTBigCodeConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
@@ -42,8 +41,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTBigCodeAttention(nn.Module):
@@ -94,8 +91,8 @@ class GPTBigCodeAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split(
@@ -105,9 +102,7 @@ class GPTBigCodeAttention(nn.Module):
],
dim=-1,
)
key_cache, value_cache = kv_cache
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output, _ = self.c_proj(attn_output)
return attn_output
@@ -165,15 +160,15 @@ class GPTBigCodeBlock(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
attn_metadata=attn_metadata,
)
# residual connection
hidden_states = attn_output + residual
@@ -211,8 +206,8 @@ class GPTBigCodeModel(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
@@ -220,7 +215,7 @@ class GPTBigCodeModel(nn.Module):
for i in range(len(self.h)):
layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
hidden_states = self.ln_f(hidden_states)
return hidden_states
@@ -245,11 +240,11 @@ class GPTBigCodeForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata)
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,