[Model] Support is_causal HF config field for Qwen2 model (#10621)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -27,7 +27,7 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import Qwen2Config
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@@ -164,11 +164,17 @@ class Qwen2Attention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q,
|
||||
k,
|
||||
v,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
attn_type=attn_type)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -210,6 +216,15 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
# By default, Qwen2 uses causal attention as it is a decoder-only model.
|
||||
# You can override the HF config with `is_causal=False` to enable
|
||||
# bidirectional attention, which is used in some embedding models
|
||||
# (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
|
||||
if getattr(config, "is_causal", True):
|
||||
self._attn_type = AttentionType.DECODER
|
||||
else:
|
||||
self._attn_type = AttentionType.ENCODER_ONLY
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@@ -230,6 +245,7 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
attn_type=self._attn_type,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
|
||||
Reference in New Issue
Block a user