[Bugfix] JAIS: Only apply ALiBi when position_embedding_type='alibi' (#37820)

Co-authored-by: r266-tech <r266-tech@users.noreply.github.com>
This commit is contained in:
r266-tech
2026-03-23 15:36:34 +08:00
committed by GitHub
parent 410d300893
commit 02e6efe56d

View File

@@ -117,11 +117,14 @@ class JAISAttention(nn.Module):
prefix=f"{prefix}.c_proj",
)
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end]
self.use_alibi = config.position_embedding_type == "alibi"
alibi_slopes = None
if self.use_alibi:
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end]
self.attn = Attention(
self.num_heads,
self.head_dim,