[Attention] Use FA4 for MLA prefill (#34732)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -1,4 +1,19 @@
|
||||
# MLA prefill-only benchmark configuration for sparse backends
|
||||
# MLA prefill backend comparison
|
||||
#
|
||||
# Compares all available MLA prefill backends:
|
||||
# FA backends: fa2, fa3, fa4 (FlashAttention versions)
|
||||
# Non-FA: flashinfer, cudnn, trtllm (Blackwell-only, require flashinfer)
|
||||
#
|
||||
# Uses cutlass_mla as the decode backend for impl construction
|
||||
# (only the prefill path is exercised).
|
||||
#
|
||||
# Backends that aren't available on the current platform will report errors
|
||||
# in the results table (e.g., fa3 on Blackwell, cudnn without artifactory).
|
||||
#
|
||||
# Usage:
|
||||
# python benchmark.py --config configs/mla_prefill.yaml
|
||||
|
||||
description: "MLA prefill backend comparison"
|
||||
|
||||
model:
|
||||
name: "deepseek-v3"
|
||||
@@ -12,20 +27,25 @@ model:
|
||||
v_head_dim: 128
|
||||
block_size: 128
|
||||
|
||||
# Model parameter sweep: simulate tensor parallelism by varying num_q_heads
|
||||
# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads
|
||||
model_parameter_sweep:
|
||||
param_name: "num_q_heads"
|
||||
values: [128, 64, 32, 16]
|
||||
label_format: "{backend}_{value}h"
|
||||
# model:
|
||||
# name: "deepseek-v2-lite"
|
||||
# num_layers: 27
|
||||
# num_q_heads: 16
|
||||
# num_kv_heads: 1
|
||||
# head_dim: 576
|
||||
# kv_lora_rank: 512
|
||||
# qk_nope_head_dim: 128
|
||||
# qk_rope_head_dim: 64
|
||||
# v_head_dim: 128
|
||||
# block_size: 128
|
||||
|
||||
batch_specs:
|
||||
# Pure prefill
|
||||
- "1q512"
|
||||
- "1q1k"
|
||||
- "1q2k"
|
||||
- "1q4k"
|
||||
- "1q8k"
|
||||
- "q512"
|
||||
- "q1k"
|
||||
- "q2k"
|
||||
- "q4k"
|
||||
- "q8k"
|
||||
|
||||
# Batched pure prefill
|
||||
- "2q512"
|
||||
@@ -44,19 +64,63 @@ batch_specs:
|
||||
- "8q4k"
|
||||
- "8q8k"
|
||||
|
||||
# Extend
|
||||
- "1q512s4k"
|
||||
- "1q512s8k"
|
||||
- "1q1ks8k"
|
||||
- "1q2ks8k"
|
||||
- "1q2ks16k"
|
||||
- "1q4ks16k"
|
||||
# Chunked prefill / extend
|
||||
# Short context
|
||||
- "q128s1k"
|
||||
- "q256s2k"
|
||||
- "q512s4k"
|
||||
- "q1ks4k"
|
||||
- "q2ks8k"
|
||||
- "2q128s1k"
|
||||
- "2q256s2k"
|
||||
- "2q512s4k"
|
||||
- "2q1ks4k"
|
||||
- "2q2ks8k"
|
||||
- "4q128s1k"
|
||||
- "4q256s2k"
|
||||
- "4q512s4k"
|
||||
- "4q1ks4k"
|
||||
- "4q2ks8k"
|
||||
- "8q128s1k"
|
||||
- "8q256s2k"
|
||||
- "8q512s4k"
|
||||
- "8q1ks4k"
|
||||
|
||||
backends:
|
||||
- FLASHMLA_SPARSE
|
||||
- FLASHINFER_MLA_SPARSE
|
||||
# Medium context
|
||||
- "q128s16k"
|
||||
- "q512s16k"
|
||||
- "q1ks16k"
|
||||
- "q2ks16k"
|
||||
- "2q128s16k"
|
||||
- "2q512s16k"
|
||||
- "2q1ks16k"
|
||||
- "2q2ks16k"
|
||||
- "4q128s16k"
|
||||
- "4q512s16k"
|
||||
- "4q1ks16k"
|
||||
- "4q2ks16k"
|
||||
|
||||
# Long context
|
||||
- "q128s64k"
|
||||
- "q512s64k"
|
||||
- "q1ks64k"
|
||||
- "q2ks64k"
|
||||
- "2q128s64k"
|
||||
- "2q512s64k"
|
||||
- "2q1ks64k"
|
||||
- "2q2ks64k"
|
||||
|
||||
decode_backends:
|
||||
- CUTLASS_MLA
|
||||
|
||||
prefill_backends:
|
||||
- fa2
|
||||
- fa3
|
||||
- fa4
|
||||
- flashinfer
|
||||
- cudnn
|
||||
- trtllm
|
||||
|
||||
device: "cuda:0"
|
||||
repeats: 10
|
||||
warmup_iters: 3
|
||||
profile_memory: true
|
||||
repeats: 20
|
||||
warmup_iters: 5
|
||||
|
||||
Reference in New Issue
Block a user