[Bugfix] Fix missing scale passing for encoder Triton Attention implementation (#32149)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -4,10 +4,7 @@
 from argparse import Namespace
 
 from vllm import LLM, EngineArgs
-from vllm.config import AttentionConfig
-from vllm.platforms import current_platform
 from vllm.utils.argparse_utils import FlexibleArgumentParser
-from vllm.v1.attention.backends.registry import AttentionBackendEnum
 
 
 def parse_args():
@@ -23,11 +20,6 @@ def parse_args():
 
 
 def main(args: Namespace):
-    if current_platform.is_rocm():
-        args.attention_config = AttentionConfig(
-            backend=AttentionBackendEnum.FLEX_ATTENTION
-        )
-
     # Sample prompts.
     prompts = [
         "Hello, my name is",
Reference in New Issue
Block a user