[Misc] Enhance attention selector (#4751)
This commit is contained in:
@@ -5,6 +5,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
@@ -215,6 +216,7 @@ class ArcticAttention(nn.Module):
|
||||
self,
|
||||
config: ArcticConfig,
|
||||
layer_idx: Optional[int] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -265,7 +267,8 @@ class ArcticAttention(nn.Module):
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -288,6 +291,7 @@ class ArcticDecoderLayer(nn.Module):
|
||||
self,
|
||||
config: ArcticConfig,
|
||||
layer_idx: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -297,6 +301,7 @@ class ArcticDecoderLayer(nn.Module):
|
||||
self.use_residual = config.use_residual and is_moe_layer
|
||||
self.self_attn = ArcticAttention(config,
|
||||
layer_idx,
|
||||
cache_config,
|
||||
quant_config=quant_config)
|
||||
self.block_sparse_moe = ArcticMoE(
|
||||
config,
|
||||
@@ -356,6 +361,7 @@ class ArcticModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: ArcticConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -366,7 +372,10 @@ class ArcticModel(nn.Module):
|
||||
config.hidden_size,
|
||||
org_num_embeddings=self.vocab_size)
|
||||
self.layers = nn.ModuleList([
|
||||
ArcticDecoderLayer(config, layer_idx, quant_config=quant_config)
|
||||
ArcticDecoderLayer(config,
|
||||
layer_idx,
|
||||
cache_config,
|
||||
quant_config=quant_config)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
self._attn_implementation = config._attn_implementation
|
||||
@@ -392,11 +401,12 @@ class ArcticForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: ArcticConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = ArcticModel(config, quant_config)
|
||||
self.model = ArcticModel(config, cache_config, quant_config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.vocab_size,
|
||||
|
||||
Reference in New Issue
Block a user