[Core] Refactor Attention Take 2 (#3462)

This commit is contained in:
Woosuk Kwon
2024-03-24 21:39:33 -07:00
committed by GitHub
parent b0dfa91dd7
commit 925f3332ca
47 changed files with 1268 additions and 1117 deletions

View File

@@ -18,15 +18,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder2 model."""
from typing import List, Optional, Tuple
from typing import List, Optional
import torch
from torch import nn
from transformers import Starcoder2Config
from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -43,8 +42,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class Starcoder2Attention(nn.Module):
@@ -111,14 +108,13 @@ class Starcoder2Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
@@ -171,8 +167,8 @@ class Starcoder2DecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
@@ -181,7 +177,7 @@ class Starcoder2DecoderLayer(nn.Module):
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states
@@ -217,14 +213,14 @@ class Starcoder2Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states = layer(positions, hidden_states, kv_caches[i],
input_metadata)
attn_metadata)
hidden_states = self.norm(hidden_states)
return hidden_states
@@ -258,11 +254,11 @@ class Starcoder2ForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata)
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,