diff --git a/examples/offline_inference/basic/embed.py b/examples/offline_inference/basic/embed.py
index 90793fb61..eeb7137ff 100644
--- a/examples/offline_inference/basic/embed.py
+++ b/examples/offline_inference/basic/embed.py
@@ -4,10 +4,7 @@
 from argparse import Namespace
 
 from vllm import LLM, EngineArgs
-from vllm.config import AttentionConfig
-from vllm.platforms import current_platform
 from vllm.utils.argparse_utils import FlexibleArgumentParser
-from vllm.v1.attention.backends.registry import AttentionBackendEnum
 
 
 def parse_args():
@@ -23,11 +20,6 @@ def parse_args():
 
 
 def main(args: Namespace):
-    if current_platform.is_rocm():
-        args.attention_config = AttentionConfig(
-            backend=AttentionBackendEnum.FLEX_ATTENTION
-        )
-
     # Sample prompts.
     prompts = [
         "Hello, my name is",
diff --git a/examples/offline_inference/basic/score.py b/examples/offline_inference/basic/score.py
index abe827043..cbca50eb5 100644
--- a/examples/offline_inference/basic/score.py
+++ b/examples/offline_inference/basic/score.py
@@ -4,10 +4,7 @@
 from argparse import Namespace
 
 from vllm import LLM, EngineArgs
-from vllm.config import AttentionConfig
-from vllm.platforms import current_platform
 from vllm.utils.argparse_utils import FlexibleArgumentParser
-from vllm.v1.attention.backends.registry import AttentionBackendEnum
 
 
 def parse_args():
@@ -23,11 +20,6 @@ def parse_args():
 
 
 def main(args: Namespace):
-    if current_platform.is_rocm():
-        args.attention_config = AttentionConfig(
-            backend=AttentionBackendEnum.FLEX_ATTENTION
-        )
-
     # Sample prompts.
     text_1 = "What is the capital of France?"
     texts_2 = [
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index ed2f9564e..3ef5b4a22 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -573,6 +573,7 @@ class TritonAttentionImpl(AttentionImpl):
                 b_seq_len=seq_lens,
                 max_input_len=max_query_len,
                 is_causal=False,  # Encoder attention is bidirectional
+                softmax_scale=self.scale,
                 sliding_window_q=self.sliding_window[0],
                 sliding_window_k=self.sliding_window[1],
             )
diff --git a/vllm/v1/attention/ops/triton_prefill_attention.py b/vllm/v1/attention/ops/triton_prefill_attention.py
index c593698f1..046d0c170 100644
--- a/vllm/v1/attention/ops/triton_prefill_attention.py
+++ b/vllm/v1/attention/ops/triton_prefill_attention.py
@@ -211,16 +211,17 @@ def get_block_size(dtype: torch.dtype) -> int:
 
 
 def context_attention_fwd(
-    q,
-    k,
-    v,
-    o,
-    b_start_loc,
-    b_seq_len,
-    max_input_len,
-    is_causal=True,
-    sliding_window_q=None,
-    sliding_window_k=None,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    o: torch.Tensor,
+    b_start_loc: torch.Tensor,
+    b_seq_len: torch.Tensor,
+    max_input_len: int,
+    is_causal: bool = True,
+    softmax_scale: float | None = None,
+    sliding_window_q: int | None = None,
+    sliding_window_k: int | None = None,
 ):
     """
     q, k, v: [b * s, head, head_dim]
@@ -232,7 +233,7 @@ def context_attention_fwd(
 
     Lq, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1]
 
-    sm_scale = 1.0 / (Lq**0.5)
+    sm_scale = 1.0 / (Lq**0.5) if softmax_scale is None else softmax_scale
     batch, head = b_seq_len.shape[0], q.shape[1]
     kv_group_num = q.shape[1] // k.shape[1]
 
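
For context, a minimal standalone sketch of the scale-selection behavior this patch introduces. `pick_sm_scale` is a hypothetical helper written only for illustration (it does not exist in vLLM); it mirrors the new fallback in `context_attention_fwd`, where an explicit `softmax_scale` (e.g. `self.scale` from `TritonAttentionImpl`) takes precedence over the default `1/sqrt(head_dim)` computed from the query's last dimension.

# Hedged illustration, not part of the diff: shows the fallback logic only.
import math


def pick_sm_scale(head_dim: int, softmax_scale: float | None = None) -> float:
    # Mirrors `sm_scale = 1.0 / (Lq**0.5) if softmax_scale is None else softmax_scale`:
    # use 1/sqrt(head_dim) only when the caller did not supply a scale.
    return 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale


if __name__ == "__main__":
    # Default path: scale recomputed from the head dimension.
    print(pick_sm_scale(head_dim=128))  # ~0.0884
    # New path: an explicit scale (e.g. the impl's configured self.scale) wins.
    print(pick_sm_scale(head_dim=128, softmax_scale=0.05))  # 0.05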