[Kernel] Move attn_type to Attention.__init__() (#11690)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-01-07 00:11:28 +08:00
committed by GitHub
parent 32c9eff2ff
commit e20c92bb61
18 changed files with 159 additions and 201 deletions

View File

@@ -107,7 +107,8 @@ class Qwen2Attention(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[Tuple] = None,
prefix: str = "") -> None:
prefix: str = "",
attn_type: str = AttentionType.DECODER) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
@@ -160,7 +161,8 @@ class Qwen2Attention(nn.Module):
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
prefix=f"{prefix}.attn",
attn_type=attn_type)
def forward(
self,
@@ -168,17 +170,11 @@ class Qwen2Attention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=attn_type)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
@@ -197,6 +193,16 @@ class Qwen2DecoderLayer(nn.Module):
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
# By default, Qwen2 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
if getattr(config, "is_causal", True):
attn_type = AttentionType.DECODER
else:
attn_type = AttentionType.ENCODER_ONLY
self.self_attn = Qwen2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@@ -207,6 +213,7 @@ class Qwen2DecoderLayer(nn.Module):
quant_config=quant_config,
rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
@@ -220,15 +227,6 @@ class Qwen2DecoderLayer(nn.Module):
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
# By default, Qwen2 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
if getattr(config, "is_causal", True):
self._attn_type = AttentionType.DECODER
else:
self._attn_type = AttentionType.ENCODER_ONLY
def forward(
self,
positions: torch.Tensor,
@@ -249,7 +247,6 @@ class Qwen2DecoderLayer(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
attn_type=self._attn_type,
)
# Fully Connected