Fix trtllm-gen attention env and add attention sink (#22378)

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Lain <fusiyuan2000@hotmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
Committed by Lain on 2025-08-06 18:07:41 -07:00 (committed via GitHub)
parent 5c7cc33f4d
commit 9a3835aaa9
5 changed files with 21 additions and 28 deletions


@@ -70,9 +70,8 @@ class OAIAttention(nn.Module):
         tp_size = get_tensor_model_parallel_world_size()
-        attention_sink_dtype = (
-            torch.float32 if envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
-            or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION else torch.bfloat16)
+        attention_sink_dtype = (torch.float32 if envs.VLLM_USE_TRTLLM_ATTENTION
+                                else torch.bfloat16)
         self.sinks = torch.nn.Parameter(
             torch.empty(config.num_attention_heads // tp_size,
                         dtype=attention_sink_dtype,
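The change above collapses the dtype choice from two separate env flags (`VLLM_USE_TRTLLM_CONTEXT_ATTENTION` or `VLLM_USE_TRTLLM_DECODE_ATTENTION`) into the single `VLLM_USE_TRTLLM_ATTENTION` flag. A minimal sketch of the consolidated selection logic, with string names standing in for `torch.float32` / `torch.bfloat16` so it runs without dependencies (the flag name and float32 requirement are taken from the diff; everything else is illustrative):

```python
def attention_sink_dtype(use_trtllm_attention: bool) -> str:
    """Pick the dtype for the attention-sink parameter.

    TRT-LLM attention kernels expect float32 sinks per the diff above;
    the non-TRT-LLM path keeps them in bfloat16.
    """
    # Single consolidated flag replaces the old context/decode pair.
    return "float32" if use_trtllm_attention else "bfloat16"
```

With the old two-flag scheme, enabling only one of context or decode attention still forced float32; the single flag makes that condition impossible to half-set.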