[Kernel] Move attn_type to Attention.__init__() (#11690)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@@ -238,7 +238,8 @@ class BertSelfAttention(nn.Module):
                               num_kv_heads=self.num_kv_heads,
                               cache_config=cache_config,
                               quant_config=quant_config,
-                              prefix=f"{prefix}.attn")
+                              prefix=f"{prefix}.attn",
+                              attn_type=AttentionType.ENCODER_ONLY)
 
     def forward(
         self,
@@ -248,12 +249,7 @@ class BertSelfAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        output = self.attn(q,
-                           k,
-                           v,
-                           kv_cache,
-                           attn_metadata,
-                           attn_type=AttentionType.ENCODER_ONLY)
+        output = self.attn(q, k, v, kv_cache, attn_metadata)
         return output
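In short, the diff binds the attention type at construction time instead of passing it on every forward() call. A minimal sketch of the calling pattern this enables, using a simplified stand-in class rather than vLLM's actual Attention implementation:

from enum import Enum, auto

class AttentionType(Enum):
    DECODER = auto()
    ENCODER_ONLY = auto()

class Attention:
    """Simplified stand-in for vLLM's Attention layer (illustration only)."""

    def __init__(self, num_heads: int,
                 attn_type: AttentionType = AttentionType.DECODER) -> None:
        # After this change, the attention type is fixed once, here,
        # instead of being re-supplied on every forward() call.
        self.num_heads = num_heads
        self.attn_type = attn_type

    def forward(self, q, k, v):
        # Dispatch on self.attn_type; callers no longer pass attn_type here.
        if self.attn_type is AttentionType.ENCODER_ONLY:
            pass  # e.g. encoder-only attention can skip KV-cache handling
        return q  # placeholder output

# Encoder-only models such as BERT now declare their type up front:
attn = Attention(num_heads=12, attn_type=AttentionType.ENCODER_ONLY)
out = attn.forward(q=..., k=..., v=...)

Binding the type in __init__() lets the layer choose its behavior once at construction, rather than branching on a per-call argument that never changes for a given model.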