[Misc] Enhance attention selector (#4751)

This commit is contained in:
Woosuk Kwon
2024-05-13 10:47:25 -07:00
committed by GitHub
parent e7c46b9527
commit 0fca3cdcf2
49 changed files with 573 additions and 220 deletions

View File

@@ -42,6 +42,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -63,6 +64,7 @@ class PhiAttention(nn.Module):
def __init__(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.total_num_heads = config.num_attention_heads
@@ -105,7 +107,10 @@ class PhiAttention(nn.Module):
max_position=max_position_embeddings,
base=rope_theta,
)
self.attn = Attention(self.num_heads, self.head_size, scaling)
self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config)
def forward(
self,
@@ -155,11 +160,12 @@ class PhiLayer(nn.Module):
def __init__(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.self_attn = PhiAttention(config, quant_config)
self.self_attn = PhiAttention(config, cache_config, quant_config)
self.mlp = PhiMLP(config, quant_config)
def forward(
@@ -186,6 +192,7 @@ class PhiModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
@@ -193,7 +200,7 @@ class PhiModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
PhiLayer(config, quant_config)
PhiLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.final_layernorm = nn.LayerNorm(config.hidden_size,
@@ -225,12 +232,13 @@ class PhiForCausalLM(nn.Module):
def __init__(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = PhiModel(config, quant_config)
self.model = PhiModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,