[Misc] Enhance attention selector (#4751)
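In short, this change threads a cache_config argument from the top-level model constructor down through each submodule to the Attention layer, so that the attention backend selection can take the KV-cache configuration into account. Below is a minimal sketch of the pattern being applied here, assuming the same Attention and CacheConfig interfaces used in the diff; the MyAttention/MyModel names and the head geometry are illustrative only, not part of this commit.

    from typing import Optional

    from torch import nn

    from vllm.attention import Attention
    from vllm.config import CacheConfig


    class MyAttention(nn.Module):
        # Illustrative attention module: it accepts cache_config and forwards
        # it to vLLM's Attention layer, mirroring StablelmAttention below.

        def __init__(self, cache_config: Optional[CacheConfig] = None) -> None:
            super().__init__()
            # Hypothetical head geometry; a real model derives these values
            # from its HuggingFace config.
            num_heads, head_dim, num_kv_heads = 32, 64, 8
            self.attn = Attention(num_heads,
                                  head_dim,
                                  head_dim**-0.5,
                                  num_kv_heads=num_kv_heads,
                                  cache_config=cache_config)


    class MyModel(nn.Module):
        # Top-level module: cache_config is accepted once and passed down,
        # just as StablelmForCausalLM hands it to StableLMEpochModel.

        def __init__(self, cache_config: Optional[CacheConfig] = None) -> None:
            super().__init__()
            self.self_attn = MyAttention(cache_config=cache_config)

Passing the config explicitly through the constructors, rather than reading it from a global, keeps each layer's dependencies visible; the diff below applies exactly this plumbing to the StableLM model.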
@@ -26,6 +26,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -72,6 +73,7 @@ class StablelmAttention(nn.Module):
 
     def __init__(self,
                  config: PretrainedConfig,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.config = config
@@ -124,7 +126,8 @@ class StablelmAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
-                              num_kv_heads=self.num_key_value_heads)
+                              num_kv_heads=self.num_key_value_heads,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -146,10 +149,11 @@ class StablelmDecoderLayer(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
-        self.self_attn = StablelmAttention(config)
+        self.self_attn = StablelmAttention(config, cache_config, quant_config)
         self.mlp = StablelmMLP(config, quant_config)
         norm_eps = getattr(config, "norm_eps",
                            getattr(config, "layer_norm_eps", 1e-05))
@@ -188,6 +192,7 @@ class StableLMEpochModel(nn.Module):
 
     def __init__(self,
                  config: PretrainedConfig,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.embed_tokens = VocabParallelEmbedding(
@@ -195,7 +200,7 @@ class StableLMEpochModel(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            StablelmDecoderLayer(config, quant_config)
+            StablelmDecoderLayer(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         norm_eps = getattr(config, "norm_eps",
@@ -227,12 +232,13 @@ class StablelmForCausalLM(nn.Module):
     def __init__(
         self,
        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = StableLMEpochModel(config, quant_config)
+        self.model = StableLMEpochModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()