[Misc] Enhance attention selector (#4751)

Authored by Woosuk Kwon on 2024-05-13 10:47:25 -07:00; committed by GitHub.
parent e7c46b9527
commit 0fca3cdcf2
49 changed files with 573 additions and 220 deletions


@@ -24,6 +24,7 @@ from torch import nn
 from transformers import GPT2Config
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -45,6 +46,7 @@ class GPT2Attention(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -70,7 +72,10 @@ class GPT2Attention(nn.Module):
             bias=True,
             quant_config=quant_config,
         )
-        self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              scale=self.scale,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -122,6 +127,7 @@ class GPT2Block(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -130,7 +136,7 @@
                      hidden_size)
 
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPT2Attention(config, quant_config)
+        self.attn = GPT2Attention(config, cache_config, quant_config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = GPT2MLP(inner_dim, config, quant_config)
 
@@ -163,6 +169,7 @@ class GPT2Model(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -174,7 +181,7 @@
         self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList([
-            GPT2Block(config, quant_config)
+            GPT2Block(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -203,12 +210,13 @@
     def __init__(
         self,
         config: GPT2Config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = GPT2Model(config, quant_config)
+        self.transformer = GPT2Model(config, cache_config, quant_config)
         self.lm_head_weight = self.transformer.wte.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
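
The change threads a `cache_config` argument through every constructor level (`GPT2LMHeadModel` → `GPT2Model` → `GPT2Block` → `GPT2Attention`) so that the `Attention` layer can see the KV-cache settings at construction time, when the attention backend is selected. Below is a minimal, self-contained sketch of that pattern, not vLLM's actual implementation: the `CacheConfig` fields, `select_backend`, and the backend names are illustrative assumptions. Passing the config down explicitly, rather than reading a global, lets each layer pick a backend that is consistent with how its KV cache will actually be allocated.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class CacheConfig:
    """Illustrative stand-in for a KV-cache config (assumed fields)."""
    block_size: int = 16
    cache_dtype: str = "auto"  # e.g. "auto" or "fp8" (assumed values)


def select_backend(head_dim: int, cache_config: Optional[CacheConfig]) -> str:
    """Toy backend selector: cache settings can rule out some kernels."""
    cache_dtype = cache_config.cache_dtype if cache_config else "auto"
    if cache_dtype.startswith("fp8"):
        return "fp8-capable-backend"  # hypothetical backend name
    if head_dim % 16 == 0:
        return "fast-backend"         # hypothetical backend name
    return "fallback-backend"


class Attention:
    """Attention layer that consults the cache config when constructed."""

    def __init__(self, num_heads: int, head_dim: int, scale: float,
                 cache_config: Optional[CacheConfig] = None) -> None:
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.backend = select_backend(head_dim, cache_config)


class Block:
    """Intermediate modules simply forward cache_config downward."""

    def __init__(self, num_heads: int, head_dim: int,
                 cache_config: Optional[CacheConfig] = None) -> None:
        self.attn = Attention(num_heads, head_dim,
                              scale=head_dim**-0.5,
                              cache_config=cache_config)


if __name__ == "__main__":
    block = Block(num_heads=12, head_dim=64,
                  cache_config=CacheConfig(cache_dtype="fp8"))
    print(block.attn.backend)  # -> "fp8-capable-backend"
```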