[Misc] Enhance attention selector (#4751)

Author: Woosuk Kwon (committed via GitHub)
Date: 2024-05-13 10:47:25 -07:00
Parent: e7c46b9527
Commit: 0fca3cdcf2
49 changed files with 573 additions and 220 deletions
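
This commit threads the engine-level CacheConfig down to every Attention layer so the attention-backend selector can see KV-cache parameters (cache dtype, block size) instead of only the keyword arguments each model happened to forward. Of the 49 changed files, the single file shown below is representative of the mechanical plumbing applied across the model zoo. As a minimal, self-contained sketch of the idea — the CacheConfig fields and the select_backend function here are illustrative stand-ins, not vLLM's actual API:

from dataclasses import dataclass
from typing import Optional

@dataclass
class CacheConfig:
    # Assumed fields for illustration; see vllm/config.py for the real class.
    block_size: int = 16       # tokens per KV-cache block
    cache_dtype: str = "auto"  # KV-cache dtype, e.g. "auto" or an fp8 variant

def select_backend(cache_config: Optional[CacheConfig]) -> str:
    """Hypothetical selector: choose an attention kernel from cache settings."""
    if cache_config is not None and cache_config.cache_dtype.startswith("fp8"):
        # A quantized KV cache rules out kernels that cannot read fp8 blocks.
        return "xformers"
    return "flash-attn"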

vllm/model_executor/models/stablelm.py

@@ -26,6 +26,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
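
The new import gives the model file access to the engine's cache configuration. Roughly how such a config is constructed in this era of vLLM (the constructor arguments below are assumed from the 0.4.x line; verify against vllm/config.py before relying on them):

from vllm.config import CacheConfig

cache_config = CacheConfig(
    block_size=16,               # tokens per KV-cache block
    gpu_memory_utilization=0.9,  # fraction of GPU memory reserved for the cache
    swap_space=4,                # CPU swap space in GiB
    cache_dtype="auto",          # "auto" keeps the model's dtype for the cache
)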
@@ -72,6 +73,7 @@ class StablelmAttention(nn.Module):
 
     def __init__(self,
                  config: PretrainedConfig,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.config = config
@@ -124,7 +126,8 @@ class StablelmAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
-                              num_kv_heads=self.num_key_value_heads)
+                              num_kv_heads=self.num_key_value_heads,
+                              cache_config=cache_config)
 
     def forward(
         self,
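
Passing cache_config into Attention is the point of the whole change: the layer that actually dispatches to an attention backend now sees the KV-cache layout at construction time. A toy stand-in showing the shape of that contract (a sketch, not vllm.attention.Attention itself):

import torch
from torch import nn

class SketchAttention(nn.Module):
    """Illustrative stand-in for vllm.attention.Attention."""

    def __init__(self, num_heads: int, head_dim: int, scale: float,
                 num_kv_heads: int, cache_config=None) -> None:
        super().__init__()
        self.scale = scale
        # The real layer would hand these values to the backend selector;
        # here we only record them to show where they become available.
        self.block_size = getattr(cache_config, "block_size", 16)
        self.cache_dtype = getattr(cache_config, "cache_dtype", "auto")

    def forward(self, q: torch.Tensor, k: torch.Tensor,
                v: torch.Tensor) -> torch.Tensor:
        # Plain scaled dot-product attention; backends differ only in kernels.
        scores = (q @ k.transpose(-2, -1)) * self.scale
        return scores.softmax(dim=-1) @ v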
@@ -146,10 +149,11 @@ class StablelmDecoderLayer(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
-        self.self_attn = StablelmAttention(config)
+        self.self_attn = StablelmAttention(config, cache_config, quant_config)
         self.mlp = StablelmMLP(config, quant_config)
         norm_eps = getattr(config, "norm_eps",
                            getattr(config, "layer_norm_eps", 1e-05))
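
Note the incidental fix in this hunk: the old code built StablelmAttention(config) without forwarding quant_config, even though both the attention module and the sibling StablelmMLP accept one, so the attention projections stayed unquantized whenever a quantization config was supplied. The new positional call forwards cache_config and quant_config together.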
@@ -188,6 +192,7 @@ class StableLMEpochModel(nn.Module):
 
     def __init__(self,
                  config: PretrainedConfig,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.embed_tokens = VocabParallelEmbedding(
@@ -195,7 +200,7 @@
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            StablelmDecoderLayer(config, quant_config)
+            StablelmDecoderLayer(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         norm_eps = getattr(config, "norm_eps",
@@ -227,12 +232,13 @@ class StablelmForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = StableLMEpochModel(config, quant_config)
+        self.model = StableLMEpochModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
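
End to end, the constructor chain now carries cache_config from the top-level model down to every Attention instance. A hedged sketch of how the pieces fit together after this commit (the config values are placeholders, and in practice vLLM's model loader, not user code, performs this construction after initializing distributed state):

from transformers import AutoConfig
from vllm.config import CacheConfig
from vllm.model_executor.models.stablelm import StablelmForCausalLM

hf_config = AutoConfig.from_pretrained("stabilityai/stablelm-2-1_6b")
cache_config = CacheConfig(block_size=16, gpu_memory_utilization=0.9,
                           swap_space=4, cache_dtype="auto")

# cache_config flows: StablelmForCausalLM -> StableLMEpochModel
#   -> StablelmDecoderLayer -> StablelmAttention -> Attention
model = StablelmForCausalLM(hf_config, cache_config=cache_config)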