From 02e6efe56d545d50d342356630430b2ad9583ca9 Mon Sep 17 00:00:00 2001
From: r266-tech
Date: Mon, 23 Mar 2026 15:36:34 +0800
Subject: [PATCH] [Bugfix] JAIS: Only apply ALiBi when position_embedding_type='alibi' (#37820)

Co-authored-by: r266-tech
---
 vllm/model_executor/models/jais.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py
index 572717b51..b8c1310cb 100644
--- a/vllm/model_executor/models/jais.py
+++ b/vllm/model_executor/models/jais.py
@@ -117,11 +117,14 @@ class JAISAttention(nn.Module):
             prefix=f"{prefix}.c_proj",
         )
 
-        tp_rank = get_tensor_model_parallel_rank()
-        head_start = tp_rank * self.num_heads
-        head_end = (tp_rank + 1) * self.num_heads
-        alibi_slopes = _get_alibi_slopes(total_num_heads)
-        alibi_slopes = alibi_slopes[head_start:head_end]
+        self.use_alibi = config.position_embedding_type == "alibi"
+        alibi_slopes = None
+        if self.use_alibi:
+            tp_rank = get_tensor_model_parallel_rank()
+            head_start = tp_rank * self.num_heads
+            head_end = (tp_rank + 1) * self.num_heads
+            alibi_slopes = _get_alibi_slopes(total_num_heads)
+            alibi_slopes = alibi_slopes[head_start:head_end]
         self.attn = Attention(
             self.num_heads,
             self.head_dim,