---
# MLA prefill backend comparison
#
# Compares all available MLA prefill backends:
# FA backends: fa2, fa3, fa4 (FlashAttention versions)
# Non-FA: flashinfer, cudnn, trtllm (Blackwell-only, require flashinfer)
#
# Uses cutlass_mla as the decode backend for impl construction
# (only the prefill path is exercised).
#
# Backends that aren't available on the current platform will report errors
# in the results table (e.g., fa3 on Blackwell, cudnn without artifactory).
#
# Usage:
#   python benchmark.py --config configs/mla_prefill.yaml

description: "MLA prefill backend comparison"

# DeepSeek-V3 MLA geometry. head_dim is the combined latent width
# (kv_lora_rank 512 + qk_rope_head_dim 64 = 576).
model:
  name: "deepseek-v3"
  num_layers: 60
  num_q_heads: 128
  num_kv_heads: 1
  head_dim: 576
  kv_lora_rank: 512
  qk_nope_head_dim: 128
  qk_rope_head_dim: 64
  v_head_dim: 128
  block_size: 128

# Alternate, smaller model for quick runs — uncomment to use instead.
# model:
#   name: "deepseek-v2-lite"
#   num_layers: 27
#   num_q_heads: 16
#   num_kv_heads: 1
#   head_dim: 576
#   kv_lora_rank: 512
#   qk_nope_head_dim: 128
#   qk_rope_head_dim: 64
#   v_head_dim: 128
#   block_size: 128

# Workload shapes. Spec grammar (presumably): [B]q<Q>[s<S>] where B is the
# batch size (default 1), Q the query/prefill length, and S the existing
# context length for chunked prefill — TODO confirm against the spec parser.
batch_specs:
  # Pure prefill
  - "q512"
  - "q1k"
  - "q2k"
  - "q4k"
  - "q8k"

  # Batched pure prefill
  - "2q512"
  - "2q1k"
  - "2q2k"
  - "2q4k"
  - "2q8k"
  - "4q512"
  - "4q1k"
  - "4q2k"
  - "4q4k"
  - "4q8k"
  - "8q512"
  - "8q1k"
  - "8q2k"
  - "8q4k"
  - "8q8k"

  # Chunked prefill / extend
  # Short context
  - "q128s1k"
  - "q256s2k"
  - "q512s4k"
  - "q1ks4k"
  - "q2ks8k"
  - "2q128s1k"
  - "2q256s2k"
  - "2q512s4k"
  - "2q1ks4k"
  - "2q2ks8k"
  - "4q128s1k"
  - "4q256s2k"
  - "4q512s4k"
  - "4q1ks4k"
  - "4q2ks8k"
  - "8q128s1k"
  - "8q256s2k"
  - "8q512s4k"
  - "8q1ks4k"
  # NOTE(review): no "8q2ks8k" entry here, unlike the 2q/4q series above —
  # confirm whether that omission is intentional (e.g. memory limits).

  # Medium context
  - "q128s16k"
  - "q512s16k"
  - "q1ks16k"
  - "q2ks16k"
  - "2q128s16k"
  - "2q512s16k"
  - "2q1ks16k"
  - "2q2ks16k"
  - "4q128s16k"
  - "4q512s16k"
  - "4q1ks16k"
  - "4q2ks16k"

  # Long context
  - "q128s64k"
  - "q512s64k"
  - "q1ks64k"
  - "q2ks64k"
  - "2q128s64k"
  - "2q512s64k"
  - "2q1ks64k"
  - "2q2ks64k"

# Decode backend used only to construct the attention impl; the decode path
# itself is not benchmarked here.
decode_backends:
  - CUTLASS_MLA

# Prefill backends under comparison. Unavailable ones (platform/build
# dependent) report errors in the results table rather than aborting the run.
prefill_backends:
  - fa2
  - fa3
  - fa4
  - flashinfer
  - cudnn
  - trtllm

device: "cuda:0"
repeats: 20        # timed iterations per (backend, batch_spec) pair
warmup_iters: 5    # untimed iterations before measurement