Fix trtllm-gen attention env and add attention sink (#22378)
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com> Signed-off-by: Lain <fusiyuan2000@hotmail.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
@@ -70,9 +70,8 @@ class OAIAttention(nn.Module):
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
attention_sink_dtype = (
|
||||
torch.float32 if envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
|
||||
or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION else torch.bfloat16)
|
||||
attention_sink_dtype = (torch.float32 if envs.VLLM_USE_TRTLLM_ATTENTION
|
||||
else torch.bfloat16)
|
||||
self.sinks = torch.nn.Parameter(
|
||||
torch.empty(config.num_attention_heads // tp_size,
|
||||
dtype=attention_sink_dtype,
|
||||
|
||||
Reference in New Issue
Block a user