Refactor Attention (#1840)
This commit is contained in:
@@ -26,13 +26,13 @@ from torch import nn
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
|
||||
PagedAttentionWithALiBi)
|
||||
from vllm.model_executor.layers.attention import PagedAttention
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
@@ -150,17 +150,20 @@ class BaiChuanAttention(nn.Module):
|
||||
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
|
||||
|
||||
scaling = self.head_dim**-0.5
|
||||
self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
|
||||
scaling, alibi_slopes)
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
scaling,
|
||||
alibi_slopes=alibi_slopes)
|
||||
else:
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
max_position=self.max_position_embeddings)
|
||||
)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attn = PagedAttention(self.num_heads, self.head_dim,
|
||||
self.scaling)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -172,14 +175,11 @@ class BaiChuanAttention(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.W_pack(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
if self.postion_embedding != "ALIBI":
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
k_cache, v_cache = kv_cache
|
||||
if self.postion_embedding == "ALIBI":
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
else:
|
||||
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||
input_metadata, cache_event)
|
||||
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user