[Misc] Enhance attention selector (#4751)

Woosuk Kwon
2024-05-13 10:47:25 -07:00
committed by GitHub
parent e7c46b9527
commit 0fca3cdcf2
49 changed files with 573 additions and 220 deletions
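
The hunks below are from the MPT model, one of the 49 files touched; the same mechanical change runs through the whole model zoo. Each module constructor gains an optional CacheConfig and threads it down to the Attention layer, so the attention backend selector can take KV-cache settings into account instead of reading them from global state. A minimal sketch of the receiving end, assuming only CacheConfig's public block_size and cache_dtype fields (illustrative, not the verbatim vLLM Attention class):

from typing import List, Optional

import torch.nn as nn

from vllm.config import CacheConfig


class Attention(nn.Module):
    # Sketch: an attention layer that accepts the cache config directly.

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 num_kv_heads: Optional[int] = None,
                 alibi_slopes: Optional[List[float]] = None,
                 cache_config: Optional[CacheConfig] = None) -> None:
        super().__init__()
        # Pull KV-cache settings from the config when provided; the
        # fallback defaults here are illustrative.
        if cache_config is not None:
            self.kv_cache_dtype = cache_config.cache_dtype
            self.block_size = cache_config.block_size
        else:
            self.kv_cache_dtype = "auto"
            self.block_size = 16
        # Backend selection would consume these values here.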


@@ -7,6 +7,7 @@ import torch
 import torch.nn as nn
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.config import CacheConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import get_act_fn
@@ -43,6 +44,7 @@ class MPTAttention(nn.Module):
     def __init__(
         self,
         config: MPTConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -107,7 +109,8 @@ class MPTAttention(nn.Module):
                               self.head_dim,
                               scaling,
                               alibi_slopes=alibi_slopes,
-                              num_kv_heads=self.num_kv_heads)
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config)
 
     def forward(
         self,
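
Handing cache_config to Attention (above) is what lets the backend selector branch on KV-cache properties, for example falling back to a backend whose paged-attention kernels support an fp8 KV cache. A hedged sketch of such a dispatch; the function name and backend strings are hypothetical, not vLLM's actual selector:

def select_backend_name(kv_cache_dtype: str) -> str:
    # Hypothetical dispatch: an fp8 KV cache rules out backends
    # without fp8 cache kernels, so fall back to one that has them.
    if kv_cache_dtype.startswith("fp8"):
        return "XFORMERS"
    return "FLASH_ATTN"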
@@ -166,12 +169,13 @@ class MPTBlock(nn.Module):
     def __init__(
         self,
         config: MPTConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         hidden_size = config.d_model
         self.norm_1 = nn.LayerNorm(hidden_size)
-        self.attn = MPTAttention(config, quant_config)
+        self.attn = MPTAttention(config, cache_config, quant_config)
         self.norm_2 = nn.LayerNorm(hidden_size)
         self.ffn = MPTMLP(config, quant_config)
 
@@ -201,6 +205,7 @@ class MPTModel(nn.Module):
     def __init__(
         self,
         config: MPTConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -211,8 +216,10 @@
             config.vocab_size,
             config.d_model,
         )
-        self.blocks = nn.ModuleList(
-            [MPTBlock(config, quant_config) for _ in range(config.n_layers)])
+        self.blocks = nn.ModuleList([
+            MPTBlock(config, cache_config, quant_config)
+            for _ in range(config.n_layers)
+        ])
         self.norm_f = nn.LayerNorm(config.d_model)
         if config.no_bias:
             for module in self.modules():
@@ -246,6 +253,7 @@ class MPTForCausalLM(nn.Module):
     def __init__(
         self,
         config: MPTConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -253,7 +261,7 @@
         assert config.tie_word_embeddings
         self.quant_config = quant_config
 
-        self.transformer = MPTModel(config, quant_config)
+        self.transformer = MPTModel(config, cache_config, quant_config)
         self.lm_head_weight = self.transformer.wte.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
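
End to end, callers now build models with a CacheConfig in hand. A usage sketch, assuming an MPTConfig instance named config; the CacheConfig argument values are illustrative:

from vllm.config import CacheConfig

cache_config = CacheConfig(block_size=16,
                           gpu_memory_utilization=0.9,
                           swap_space=4,
                           cache_dtype="auto")
model = MPTForCausalLM(config, cache_config=cache_config)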