Refactor Attention (#1840)

Author: Woosuk Kwon
Date: 2023-11-29 15:37:31 -08:00
Committed by: GitHub
Parent: 0229c386c5
Commit: a9e4574261
16 changed files with 354 additions and 492 deletions


@@ -8,7 +8,7 @@ import torch.nn as nn
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
@@ -87,8 +87,10 @@ class MPTAttention(nn.Module):
         self.head_dim = self.d_model // self.total_num_heads
         scaling = self.head_dim**-0.5
-        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
-                                            scaling, alibi_slopes)
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scaling,
+                                   alibi_slopes=alibi_slopes)
 
     def forward(
         self,
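
To illustrate the interface change, here is a minimal, self-contained Python sketch of the pattern this refactor adopts: a single attention class whose optional alibi_slopes keyword subsumes the separate PagedAttentionWithALiBi subclass. The class name SimplePagedAttention, its forward signature, and the dense (non-paged) attention body are assumptions made for clarity; this is not vLLM's actual implementation, which operates on paged KV-cache blocks and InputMetadata.

from typing import List, Optional

import torch
import torch.nn as nn


class SimplePagedAttention(nn.Module):
    # Hypothetical sketch: one class covers both the plain and the ALiBi
    # case via an optional keyword, mirroring the refactored interface.

    def __init__(self,
                 num_heads: int,
                 head_dim: int,
                 scale: float,
                 alibi_slopes: Optional[List[float]] = None) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        if alibi_slopes is not None:
            # Register as a buffer so the slopes follow the module's device.
            slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
            self.register_buffer("alibi_slopes", slopes, persistent=False)
        else:
            self.alibi_slopes = None

    def forward(self, query: torch.Tensor, key: torch.Tensor,
                value: torch.Tensor) -> torch.Tensor:
        # query/key/value: [num_heads, seq_len, head_dim]
        scores = torch.einsum("hqd,hkd->hqk", query, key) * self.scale
        if self.alibi_slopes is not None:
            # ALiBi: bias each score by -slope * |key_pos - query_pos|,
            # penalizing distant keys instead of adding position embeddings.
            seq_len = query.shape[1]
            pos = torch.arange(seq_len, device=query.device)
            rel = -(pos[None, :] - pos[:, None]).abs().to(scores.dtype)
            scores = scores + self.alibi_slopes[:, None, None] * rel[None]
        probs = torch.softmax(scores, dim=-1)
        return torch.einsum("hqk,hkd->hqd", probs, value)


# Usage mirroring the MPT change above (toy sizes, random inputs):
attn = SimplePagedAttention(num_heads=4, head_dim=8, scale=8**-0.5,
                            alibi_slopes=[0.5, 0.25, 0.125, 0.0625])
q = torch.randn(4, 16, 8)
k = torch.randn(4, 16, 8)
v = torch.randn(4, 16, 8)
out = attn(q, k, v)  # -> [4, 16, 8]

Collapsing the ALiBi variant into a keyword argument keeps a single attention code path per model, which is what lets each model file import just PagedAttention, as in the hunks above.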