[Model] Support is_causal HF config field for Qwen2 model (#10621)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2024-11-25 17:51:20 +08:00
Committed by: GitHub
Parent: 05d1f8c9c6
Commit: ed46f14321
5 changed files with 51 additions and 13 deletions

@@ -27,7 +27,7 @@ import torch
 from torch import nn
 from transformers import Qwen2Config
 
-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention, AttentionMetadata, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -164,11 +164,17 @@ class Qwen2Attention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        attn_type: str = AttentionType.DECODER,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q,
+                                k,
+                                v,
+                                kv_cache,
+                                attn_metadata,
+                                attn_type=attn_type)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -210,6 +216,15 @@ class Qwen2DecoderLayer(nn.Module):
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
 
+        # By default, Qwen2 uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
+        if getattr(config, "is_causal", True):
+            self._attn_type = AttentionType.DECODER
+        else:
+            self._attn_type = AttentionType.ENCODER_ONLY
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -230,6 +245,7 @@
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             attn_metadata=attn_metadata,
+            attn_type=self._attn_type,
         )
 
         # Fully Connected
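
Usage note (not part of this commit): a minimal sketch of how the new
`is_causal` handling can be exercised from vLLM's offline API. It assumes
the `hf_overrides` engine argument, the `task="embedding"` mode, and the
`encode()` entry point for embedding models, all present in vLLM around the
time of this commit; argument and method names may differ in other versions.

    # Sketch only: exercising the is_causal override added in this commit.
    from vllm import LLM

    # `is_causal=False` in the HF config switches Qwen2DecoderLayer to
    # AttentionType.ENCODER_ONLY, enabling the bidirectional attention that
    # gte-Qwen2 embedding models expect.
    llm = LLM(
        model="Alibaba-NLP/gte-Qwen2-7B-instruct",
        task="embedding",
        hf_overrides={"is_causal": False},
    )

    outputs = llm.encode(["What is the capital of France?"])
    print(len(outputs[0].outputs.embedding))  # embedding dimensionality

Without the override, `getattr(config, "is_causal", True)` keeps the default
causal mask, so decoder-only generation models are unaffected.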