[Misc] Enhance attention selector (#4751)

Author: Woosuk Kwon
Date: 2024-05-13 10:47:25 -07:00
Committed by: GitHub
Parent: e7c46b9527
Commit: 0fca3cdcf2
49 changed files with 573 additions and 220 deletions


@@ -24,6 +24,7 @@ from torch import nn
 from transformers import BloomConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.config import CacheConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import get_act_fn
@@ -71,6 +72,7 @@ class BloomAttention(nn.Module):
     def __init__(
         self,
         config: BloomConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -108,7 +110,8 @@ class BloomAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               scaling,
-                              alibi_slopes=alibi_slopes)
+                              alibi_slopes=alibi_slopes,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -158,6 +161,7 @@ class BloomBlock(nn.Module):
     def __init__(
         self,
         config: BloomConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -165,7 +169,8 @@ class BloomBlock(nn.Module):
         self.input_layernorm = nn.LayerNorm(hidden_size,
                                             eps=config.layer_norm_epsilon)
-        self.self_attention = BloomAttention(config, quant_config)
+        self.self_attention = BloomAttention(config, cache_config,
+                                             quant_config)
         self.post_attention_layernorm = nn.LayerNorm(
             hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = BloomMLP(config, quant_config)
@@ -214,6 +219,7 @@ class BloomModel(nn.Module):
     def __init__(
         self,
         config: BloomConfig,
+        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -229,7 +235,7 @@ class BloomModel(nn.Module):
         # Transformer blocks
         self.h = nn.ModuleList([
-            BloomBlock(config, quant_config)
+            BloomBlock(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
@@ -262,12 +268,13 @@ class BloomForCausalLM(nn.Module):
     def __init__(
         self,
         config: BloomConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = BloomModel(config, quant_config)
+        self.transformer = BloomModel(config, cache_config, quant_config)
         self.lm_head_weight = self.transformer.word_embeddings.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
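
Every hunk in this commit applies the same change: the top-level model constructor accepts an optional CacheConfig and threads it down, unchanged, to every Attention layer, so that the attention layer (and, per the commit title, the backend selector behind it) can see cache settings such as the KV-cache data type. Below is a minimal sketch of that pattern, not an excerpt from the diff; ToyAttention and ToyModel are hypothetical names, while Attention and CacheConfig are the real vLLM classes shown above.

# Minimal sketch of the constructor-threading pattern used in this commit.
# ToyAttention and ToyModel are hypothetical; the Attention call mirrors the
# BloomAttention hunk above (positional heads/head_dim/scale, keyword
# cache_config).
from typing import Optional

import torch.nn as nn

from vllm.attention import Attention
from vllm.config import CacheConfig


class ToyAttention(nn.Module):

    def __init__(self,
                 num_heads: int,
                 head_dim: int,
                 cache_config: Optional[CacheConfig] = None):
        super().__init__()
        # Forward cache_config so the attention layer sees the cache settings.
        self.attn = Attention(num_heads,
                              head_dim,
                              head_dim**-0.5,
                              cache_config=cache_config)


class ToyModel(nn.Module):

    def __init__(self, cache_config: Optional[CacheConfig] = None):
        super().__init__()
        # The same cache_config instance is handed to every layer, mirroring
        # the BloomForCausalLM -> BloomModel -> BloomBlock -> BloomAttention
        # chain in the diff above.
        self.layers = nn.ModuleList(
            [ToyAttention(16, 64, cache_config) for _ in range(2)])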