[Model] GritLM supports other attention backends (#18109)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-05-14 18:33:19 +08:00; committed by GitHub
parent 259127f8b8
commit d62a076e84
4 changed files with 85 additions and 108 deletions

vllm/model_executor/models/llama.py

@@ -28,7 +28,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
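
The `-`/`+` pair above swaps the bare `Attention` import for one that also brings in `AttentionType`, the set of string tags vLLM uses to tell a backend which masking pattern to build. A minimal sketch of the two members this commit relies on (the exact string values are an assumption from vLLM's source around this release, not guaranteed by the diff):

```python
# Hedged sketch of the relevant AttentionType members (values assumed).
class AttentionType:
    DECODER = "decoder"            # causal self-attention, Llama's default
    ENCODER_ONLY = "encoder_only"  # bidirectional self-attention, embeddings
```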
@@ -96,19 +96,22 @@ class LlamaMLP(nn.Module):
 class LlamaAttention(nn.Module):
-    def __init__(self,
-                 config: LlamaConfig,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 rope_theta: float = 10000,
-                 rope_scaling: Optional[Dict[str, Any]] = None,
-                 max_position_embeddings: int = 8192,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 bias: bool = False,
-                 bias_o_proj: bool = False,
-                 cache_config: Optional[CacheConfig] = None,
-                 prefix: str = "") -> None:
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        bias_o_proj: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+    ) -> None:
         super().__init__()
         layer_idx = extract_layer_index(prefix)
         self.hidden_size = hidden_size
@@ -194,6 +197,7 @@ class LlamaAttention(nn.Module):
             cache_config=cache_config,
             quant_config=quant_config,
             per_layer_sliding_window=sliding_window,
+            attn_type=attn_type,
             prefix=f"{prefix}.attn",
         )
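
Forwarding `attn_type` into the shared `Attention` wrapper means the masking mode travels with the layer instead of being special-cased in GritLM's code, so any backend that understands `ENCODER_ONLY` can serve the model. Backend choice itself stays orthogonal; a sketch (`VLLM_ATTENTION_BACKEND` is vLLM's standard selector, though accepted values vary by version and hardware):

```python
# Pick an attention backend before building the engine; after this commit,
# GritLM's bidirectional layers are no longer pinned to a single backend.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"  # or "XFORMERS", ...
```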
@@ -238,6 +242,15 @@ class LlamaDecoderLayer(nn.Module):
         if hasattr(config, 'qkv_bias'):
             attention_bias = config.qkv_bias
 
+        # By default, Llama uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. parasail-ai/GritLM-7B-vllm)
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
         self.self_attn = LlamaAttention(
             config=config,
             hidden_size=self.hidden_size,
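
The selection rule added above is small enough to check in isolation. A standalone restatement, with `SimpleNamespace` standing in for a `transformers.LlamaConfig` and local constants standing in for the assumed `AttentionType` values:

```python
from types import SimpleNamespace

DECODER = "decoder"            # stand-in for AttentionType.DECODER
ENCODER_ONLY = "encoder_only"  # stand-in for AttentionType.ENCODER_ONLY

def select_attn_type(config) -> str:
    # A missing attribute (or True) keeps Llama's default causal attention.
    return DECODER if getattr(config, "is_causal", True) else ENCODER_ONLY

assert select_attn_type(SimpleNamespace()) == DECODER
assert select_attn_type(SimpleNamespace(is_causal=False)) == ENCODER_ONLY
```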
@@ -252,6 +265,7 @@ class LlamaDecoderLayer(nn.Module):
             bias_o_proj=bias_o_proj,
             cache_config=cache_config,
             prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
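
In practice, the override mentioned in the code comment can be applied from the embedding entry point. A hedged usage sketch (assumes a vLLM build of this vintage where `LLM` accepts `hf_overrides` and `task="embed"`; published GritLM-for-vLLM checkpoints may already set `is_causal=False`, making the override redundant):

```python
from vllm import LLM

# Enable bidirectional attention for embedding, per the comment above.
llm = LLM(
    model="parasail-ai/GritLM-7B-vllm",
    task="embed",
    hf_overrides={"is_causal": False},
)
outputs = llm.embed(["GritLM produces sentence embeddings."])
```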