[Core] Refactor Attention Take 2 (#3462)

This commit is contained in:
Woosuk Kwon
2024-03-24 21:39:33 -07:00
committed by GitHub
parent b0dfa91dd7
commit 925f3332ca
47 changed files with 1268 additions and 1117 deletions

View File

@@ -1,14 +1,13 @@
 # coding=utf-8
 # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
 import math
-from typing import List, Optional, Tuple
+from typing import List, Optional
 import torch
 import torch.nn as nn
-from vllm.model_executor.input_metadata import InputMetadata
+from vllm.attention import Attention, AttentionMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 LinearMethodBase,
 QKVParallelLinear,
@@ -25,8 +24,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs.mpt import MPTConfig
-KVCache = Tuple[torch.Tensor, torch.Tensor]
 def _get_alibi_slopes(
 total_num_heads: int,
@@ -116,8 +113,8 @@ class MPTAttention(nn.Module):
 self,
 position_ids: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: KVCache,
-input_metadata: InputMetadata,
+kv_cache: torch.Tensor,
+attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 del position_ids # unused.
 qkv, _ = self.Wqkv(hidden_states)
@@ -127,8 +124,7 @@ class MPTAttention(nn.Module):
 if self.qk_ln:
 q = self.q_ln(q)
 k = self.k_ln(k)
-k_cache, v_cache = kv_cache
-attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
 output, _ = self.out_proj(attn_output)
 return output
@@ -184,15 +180,15 @@ class MPTBlock(nn.Module):
 self,
 position_ids: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: KVCache,
-input_metadata: InputMetadata,
+kv_cache: torch.Tensor,
+attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 x = self.norm_1(hidden_states)
 x = self.attn(
 position_ids=position_ids,
 hidden_states=x,
 kv_cache=kv_cache,
-input_metadata=input_metadata,
+attn_metadata=attn_metadata,
 )
 hidden_states = hidden_states + x
 x = self.norm_2(hidden_states)
@@ -230,8 +226,8 @@ class MPTModel(nn.Module):
 self,
 input_ids: torch.Tensor,
 position_ids: torch.Tensor,
-kv_caches: List[KVCache],
-input_metadata: InputMetadata,
+kv_caches: List[torch.Tensor],
+attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 hidden_states = self.wte(input_ids)
 for i in range(len(self.blocks)):
@@ -240,7 +236,7 @@ class MPTModel(nn.Module):
 position_ids,
 hidden_states,
 kv_caches[i],
-input_metadata,
+attn_metadata,
 )
 hidden_states = self.norm_f(hidden_states)
 return hidden_states
@@ -267,11 +263,11 @@ class MPTForCausalLM(nn.Module):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[KVCache],
-input_metadata: InputMetadata,
+kv_caches: List[torch.Tensor],
+attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 hidden_states = self.transformer(input_ids, positions, kv_caches,
-input_metadata)
+attn_metadata)
 return hidden_states
 def compute_logits(self, hidden_states: torch.Tensor,