Refactor Attention (#1840)
@@ -10,12 +10,13 @@ from torch.nn import LayerNorm
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
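In short: the fused PagedAttentionWithRoPE layer is split into two pieces, a rotary-embedding module built by get_rope and a plain PagedAttention op. That is why the attention import swaps to PagedAttention and a rotary_embedding import is added; the constructor and forward hunks below apply the same split.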
@@ -78,16 +79,19 @@ class GLMAttention(nn.Module):
         # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
         rope_ratio = getattr(config, "rope_ratio", 1.0)
         max_positions = getattr(config, "seq_length", 8192)
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
+        self.rotary_emb = get_rope(
             self.head_dim,
-            self.scaling,
             rotary_dim=self.head_dim // 2,
-            num_kv_heads=self.num_kv_heads,
             max_position=max_positions,
             base=10000 * rope_ratio,
             is_neox_style=False,
         )
+        self.attn = PagedAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+        )

     def forward(
         self,
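A minimal sketch of the construction pattern this hunk introduces. The get_rope and PagedAttention arguments follow the hunk above; the concrete head counts, head_dim, rope_ratio, and the 1/sqrt(head_dim) scaling are illustrative assumptions, not values taken from this diff.

from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.rotary_embedding import get_rope

# Illustrative sizes only; the real values come from the model config.
num_heads, num_kv_heads, head_dim = 32, 2, 128
rope_ratio, max_positions = 1.0, 8192

# Rotary embedding is now a standalone module built by get_rope...
rotary_emb = get_rope(
    head_dim,
    rotary_dim=head_dim // 2,   # RoPE covers half of each head's dims
    max_position=max_positions,
    base=10000 * rope_ratio,
    is_neox_style=False,        # interleaved rotary, as in ChatGLM
)
# ...and PagedAttention is a plain attention op with no RoPE baked in.
attn = PagedAttention(
    num_heads,
    head_dim,
    head_dim**-0.5,             # assumed scaling of 1 / sqrt(head_dim)
    num_kv_heads=num_kv_heads,
)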
@@ -99,10 +103,9 @@ class GLMAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
         key_cache, value_cache = kv_cache
         context_layer = self.attn(
-            position_ids,
             q,
             k,
             v,
@@ -111,9 +114,7 @@ class GLMAttention(nn.Module):
             input_metadata,
-            cache_event,
         )
         attn_output, _ = self.dense(context_layer)
         return attn_output
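Assembling the forward-pass hunks, the refactored attention path reads roughly as follows. The method signature and the untyped style are an assumption for readability; the call sequence matches the hunks above.

def forward(self, hidden_states, position_ids, kv_cache, input_metadata):
    # Fused QKV projection, then split into query/key/value.
    qkv, _ = self.query_key_value(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    # RoPE is applied explicitly here instead of inside the attention op.
    q, k = self.rotary_emb(position_ids, q, k)
    key_cache, value_cache = kv_cache
    # PagedAttention no longer takes position_ids or a cache_event.
    context_layer = self.attn(q, k, v, key_cache, value_cache, input_metadata)
    attn_output, _ = self.dense(context_layer)
    return attn_output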