diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 572717b51..b8c1310cb 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -117,11 +117,14 @@ class JAISAttention(nn.Module): prefix=f"{prefix}.c_proj", ) - tp_rank = get_tensor_model_parallel_rank() - head_start = tp_rank * self.num_heads - head_end = (tp_rank + 1) * self.num_heads - alibi_slopes = _get_alibi_slopes(total_num_heads) - alibi_slopes = alibi_slopes[head_start:head_end] + self.use_alibi = config.position_embedding_type == "alibi" + alibi_slopes = None + if self.use_alibi: + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(total_num_heads) + alibi_slopes = alibi_slopes[head_start:head_end] self.attn = Attention( self.num_heads, self.head_dim,