[Bugfix] Fix missing scale passing for encoder Triton Attention implementation (#32149)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2026-01-12 19:13:41 +08:00
committed by GitHub
parent a5f89ae296
commit 9dbe1fe960
4 changed files with 13 additions and 27 deletions

View File

@@ -4,10 +4,7 @@
from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.config import AttentionConfig
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.attention.backends.registry import AttentionBackendEnum
def parse_args():
@@ -23,11 +20,6 @@ def parse_args():
def main(args: Namespace):
if current_platform.is_rocm():
args.attention_config = AttentionConfig(
backend=AttentionBackendEnum.FLEX_ATTENTION
)
# Sample prompts.
prompts = [
"Hello, my name is",

View File

@@ -4,10 +4,7 @@
from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.config import AttentionConfig
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.attention.backends.registry import AttentionBackendEnum
def parse_args():
@@ -23,11 +20,6 @@ def parse_args():
def main(args: Namespace):
if current_platform.is_rocm():
args.attention_config = AttentionConfig(
backend=AttentionBackendEnum.FLEX_ATTENTION
)
# Sample prompts.
text_1 = "What is the capital of France?"
texts_2 = [

View File

@@ -573,6 +573,7 @@ class TritonAttentionImpl(AttentionImpl):
b_seq_len=seq_lens,
max_input_len=max_query_len,
is_causal=False, # Encoder attention is bidirectional
softmax_scale=self.scale,
sliding_window_q=self.sliding_window[0],
sliding_window_k=self.sliding_window[1],
)

View File

@@ -211,16 +211,17 @@ def get_block_size(dtype: torch.dtype) -> int:
def context_attention_fwd(
q,
k,
v,
o,
b_start_loc,
b_seq_len,
max_input_len,
is_causal=True,
sliding_window_q=None,
sliding_window_k=None,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o: torch.Tensor,
b_start_loc: torch.Tensor,
b_seq_len: torch.Tensor,
max_input_len: int,
is_causal: bool = True,
softmax_scale: float | None = None,
sliding_window_q: int | None = None,
sliding_window_k: int | None = None,
):
"""
q, k, v: [b * s, head, head_dim]
@@ -232,7 +233,7 @@ def context_attention_fwd(
Lq, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1]
sm_scale = 1.0 / (Lq**0.5)
sm_scale = 1.0 / (Lq**0.5) if softmax_scale is None else softmax_scale
batch, head = b_seq_len.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k.shape[1]