[Model] GritLM supports other attention backends (#18109)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -28,7 +28,7 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@@ -96,19 +96,22 @@ class LlamaMLP(nn.Module):
|
||||
|
||||
class LlamaAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: LlamaConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
bias_o_proj: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
bias_o_proj: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
self.hidden_size = hidden_size
|
||||
@@ -194,6 +197,7 @@ class LlamaAttention(nn.Module):
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
per_layer_sliding_window=sliding_window,
|
||||
attn_type=attn_type,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
@@ -238,6 +242,15 @@ class LlamaDecoderLayer(nn.Module):
|
||||
if hasattr(config, 'qkv_bias'):
|
||||
attention_bias = config.qkv_bias
|
||||
|
||||
# By default, Llama uses causal attention as it is a decoder-only model.
|
||||
# You can override the HF config with `is_causal=False` to enable
|
||||
# bidirectional attention, which is used in some embedding models
|
||||
# (e.g. parasail-ai/GritLM-7B-vllm)
|
||||
if getattr(config, "is_causal", True):
|
||||
attn_type = AttentionType.DECODER
|
||||
else:
|
||||
attn_type = AttentionType.ENCODER_ONLY
|
||||
|
||||
self.self_attn = LlamaAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
@@ -252,6 +265,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
bias_o_proj=bias_o_proj,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
attn_type=attn_type,
|
||||
)
|
||||
self.mlp = LlamaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
|
||||
Reference in New Issue
Block a user