[Attention] Add FlashInfer Sparse MLA backend (#33451)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
@@ -43,6 +43,7 @@ from common import (
|
||||
ModelParameterSweep,
|
||||
ParameterSweep,
|
||||
ResultsFormatter,
|
||||
batch_spec_sort_key,
|
||||
is_mla_backend,
|
||||
)
|
||||
|
||||
@@ -218,10 +219,13 @@ def run_model_parameter_sweep(
|
||||
by_param_and_spec[key].append(r)
|
||||
break
|
||||
|
||||
# Sort by param value then spec
|
||||
# Sort by param value then spec (batch_size, q_len, kv_len)
|
||||
sorted_keys = sorted(
|
||||
by_param_and_spec.keys(),
|
||||
key=lambda x: (int(x[0]) if x[0].isdigit() else x[0], x[1]),
|
||||
key=lambda x: (
|
||||
int(x[0]) if x[0].isdigit() else x[0],
|
||||
batch_spec_sort_key(x[1]),
|
||||
),
|
||||
)
|
||||
|
||||
current_param_value = None
|
||||
@@ -330,7 +334,7 @@ def run_parameter_sweep(
|
||||
by_spec[spec] = []
|
||||
by_spec[spec].append(r)
|
||||
|
||||
for spec in sorted(by_spec.keys()):
|
||||
for spec in sorted(by_spec.keys(), key=batch_spec_sort_key):
|
||||
results = by_spec[spec]
|
||||
best = min(results, key=lambda r: r.mean_time)
|
||||
console.print(
|
||||
@@ -496,15 +500,18 @@ def main():
|
||||
if "description" in yaml_config:
|
||||
console.print(f"[dim]{yaml_config['description']}[/]")
|
||||
|
||||
# Override args with YAML values
|
||||
# (YAML takes precedence unless CLI arg was explicitly set)
|
||||
# Backend(s)
|
||||
if "backend" in yaml_config:
|
||||
args.backend = yaml_config["backend"]
|
||||
args.backends = None
|
||||
elif "backends" in yaml_config:
|
||||
args.backends = yaml_config["backends"]
|
||||
args.backend = None
|
||||
# Override args with YAML values, but CLI args take precedence
|
||||
# Check if CLI provided backends (they would be non-None and not default)
|
||||
cli_backends_provided = args.backends is not None or args.backend is not None
|
||||
|
||||
# Backend(s) - only use YAML if CLI didn't specify
|
||||
if not cli_backends_provided:
|
||||
if "backend" in yaml_config:
|
||||
args.backend = yaml_config["backend"]
|
||||
args.backends = None
|
||||
elif "backends" in yaml_config:
|
||||
args.backends = yaml_config["backends"]
|
||||
args.backend = None
|
||||
|
||||
# Check for special modes
|
||||
if "mode" in yaml_config:
|
||||
@@ -544,13 +551,15 @@ def main():
|
||||
args.num_kv_heads = model.get("num_kv_heads", args.num_kv_heads)
|
||||
args.block_size = model.get("block_size", args.block_size)
|
||||
|
||||
# Benchmark settings
|
||||
if "benchmark" in yaml_config:
|
||||
bench = yaml_config["benchmark"]
|
||||
args.device = bench.get("device", args.device)
|
||||
args.repeats = bench.get("repeats", args.repeats)
|
||||
args.warmup_iters = bench.get("warmup_iters", args.warmup_iters)
|
||||
args.profile_memory = bench.get("profile_memory", args.profile_memory)
|
||||
# Benchmark settings (top-level keys)
|
||||
if "device" in yaml_config:
|
||||
args.device = yaml_config["device"]
|
||||
if "repeats" in yaml_config:
|
||||
args.repeats = yaml_config["repeats"]
|
||||
if "warmup_iters" in yaml_config:
|
||||
args.warmup_iters = yaml_config["warmup_iters"]
|
||||
if "profile_memory" in yaml_config:
|
||||
args.profile_memory = yaml_config["profile_memory"]
|
||||
|
||||
# Parameter sweep configuration
|
||||
if "parameter_sweep" in yaml_config:
|
||||
|
||||
Reference in New Issue
Block a user