[Bugfix] Fix workspace buffer None issue for Flashinfer TRTLLM Backend (#21525)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
elvischenv
2025-07-29 22:34:00 +08:00
committed by GitHub
parent ad341c5194
commit 58b11b24a6
4 changed files with 60 additions and 41 deletions

View File

@@ -71,22 +71,20 @@ def benchmark_decode(
if kv_cache_dtype.startswith("fp8"):
kv_cache, _ = to_float8(kv_cache)
output_trtllm = torch.empty(q.shape, dtype=dtype)
# Benchmark TRT decode
def trt_decode():
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
q,
kv_cache,
workspace_buffer,
num_qo_heads,
num_kv_heads,
sm_scale,
block_tables,
kv_lens_tensor,
page_size,
max_kv_len,
kv_cache_dtype,
k_scale,
v_scale,
bmm1_scale=k_scale * sm_scale,
bmm2_scale=v_scale,
out=output_trtllm,
)
def time_fn(fn, warmup=10, trials=20):
@@ -125,6 +123,8 @@ def benchmark_decode(
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
output_baseline = torch.empty(q.shape, dtype=dtype)
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout,
@@ -145,7 +145,7 @@ def benchmark_decode(
)
def baseline_decode():
return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale)
return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline)
baseline_mean, baseline_std = time_fn(baseline_decode)
@@ -214,25 +214,39 @@ if __name__ == "__main__":
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
all_results = []
print("Running benchmark for kv_cache_dtype: bfloat16")
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
"output_dtype: bfloat16"
)
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in num_seqs:
result = benchmark_decode(
bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="auto"
bs,
max_seq_len,
dtype=torch.bfloat16,
kv_cache_dtype="auto",
)
all_results.append(result)
print("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8")
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, "
"output_dtype: bfloat16"
)
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in num_seqs:
result = benchmark_decode(
bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="fp8"
bs,
max_seq_len,
dtype=torch.bfloat16,
kv_cache_dtype="fp8",
)
all_results.append(result)