Refactor Attention (#1840)

This commit is contained in:
Woosuk Kwon
2023-11-29 15:37:31 -08:00
committed by GitHub
parent 0229c386c5
commit a9e4574261
16 changed files with 354 additions and 492 deletions

View File

@@ -29,12 +29,13 @@ from transformers import LlamaConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
@@ -126,15 +127,18 @@ class LlamaAttention(nn.Module):
bias=False,
linear_method=linear_method,
)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.rotary_emb = get_rope(
self.head_dim,
self.scaling,
base=self.rope_theta,
max_position=self.max_position_embeddings,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads,
rope_scaling=rope_scaling)
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
def forward(
self,
@@ -146,9 +150,10 @@ class LlamaAttention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
output, _ = self.o_proj(attn_output)
return output