[Bugfix] Fix missing scale passing for encoder Triton Attention implementation (#32149)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -4,10 +4,7 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.config import AttentionConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
|
||||
def parse_args():
|
||||
@@ -23,11 +20,6 @@ def parse_args():
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
if current_platform.is_rocm():
|
||||
args.attention_config = AttentionConfig(
|
||||
backend=AttentionBackendEnum.FLEX_ATTENTION
|
||||
)
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
|
||||
@@ -4,10 +4,7 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.config import AttentionConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
|
||||
def parse_args():
|
||||
@@ -23,11 +20,6 @@ def parse_args():
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
if current_platform.is_rocm():
|
||||
args.attention_config = AttentionConfig(
|
||||
backend=AttentionBackendEnum.FLEX_ATTENTION
|
||||
)
|
||||
|
||||
# Sample prompts.
|
||||
text_1 = "What is the capital of France?"
|
||||
texts_2 = [
|
||||
|
||||
@@ -573,6 +573,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
b_seq_len=seq_lens,
|
||||
max_input_len=max_query_len,
|
||||
is_causal=False, # Encoder attention is bidirectional
|
||||
softmax_scale=self.scale,
|
||||
sliding_window_q=self.sliding_window[0],
|
||||
sliding_window_k=self.sliding_window[1],
|
||||
)
|
||||
|
||||
@@ -211,16 +211,17 @@ def get_block_size(dtype: torch.dtype) -> int:
|
||||
|
||||
|
||||
def context_attention_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
max_input_len,
|
||||
is_causal=True,
|
||||
sliding_window_q=None,
|
||||
sliding_window_k=None,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
b_start_loc: torch.Tensor,
|
||||
b_seq_len: torch.Tensor,
|
||||
max_input_len: int,
|
||||
is_causal: bool = True,
|
||||
softmax_scale: float | None = None,
|
||||
sliding_window_q: int | None = None,
|
||||
sliding_window_k: int | None = None,
|
||||
):
|
||||
"""
|
||||
q, k, v: [b * s, head, head_dim]
|
||||
@@ -232,7 +233,7 @@ def context_attention_fwd(
|
||||
|
||||
Lq, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
|
||||
sm_scale = 1.0 / (Lq**0.5)
|
||||
sm_scale = 1.0 / (Lq**0.5) if softmax_scale is None else softmax_scale
|
||||
batch, head = b_seq_len.shape[0], q.shape[1]
|
||||
kv_group_num = q.shape[1] // k.shape[1]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user