[Bugfix] Fix workspace buffer None issue for Flashinfer TRTLLM Backend (#21525)

commit 58b11b24a6
parent ad341c5194
Author: elvischenv <219235043+elvischenv@users.noreply.github.com>
Date: 2025-07-29 22:34:00 +08:00
Committed-by: GitHub
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>

4 changed files with 60 additions and 41 deletions


@@ -113,27 +113,25 @@ def test_flashinfer_trtllm_decode_with_baseline(
                  kv_data_type=dtype,
                  logits_soft_cap=soft_cap)
-    output = wrapper.run(query, key_value_cache, scale)
+    output = torch.empty(query.shape, dtype=dtype)
+    wrapper.run(query, key_value_cache, scale, out=output)
 
     # TRTLLM Decode
     max_kv_len = max(kv_lens)
     kv_lens_tensor = torch.tensor(kv_lens,
                                   dtype=torch.int,
                                   device=query.device)
-    output_trtllm = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
+    output_trtllm = torch.empty(query.shape, dtype=dtype)
+    flashinfer.decode.trtllm_batch_decode_with_kv_cache(
         query.contiguous(),
         key_value_cache,
         workspace_buffer,
-        num_query_heads,
-        num_kv_heads,
-        scale,
         block_tables,
         kv_lens_tensor,
-        block_size,
         max_kv_len,
-        "auto",
-        k_scale,
-        v_scale,
+        bmm1_scale=k_scale * scale,
+        bmm2_scale=v_scale,
+        out=output_trtllm,
     )
     torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
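
For reference, below is a minimal, self-contained sketch of the calling convention the updated test exercises: the caller preallocates the output tensor and hands it to the kernel via out=, and the old positional scale arguments are folded into bmm1_scale/bmm2_scale. Only the keyword arguments and argument order mirror the hunk above; the KV-cache layout, tensor shapes, 128 MB workspace size, and unit scales are illustrative assumptions, not values taken from this commit.

import torch
import flashinfer

device = "cuda"
dtype = torch.float16
num_seqs, num_query_heads, num_kv_heads = 2, 8, 8
head_size, block_size, blocks_per_seq = 128, 16, 8

query = torch.randn(num_seqs, num_query_heads, head_size,
                    dtype=dtype, device=device)

# Paged KV cache; the (num_blocks, 2, num_kv_heads, block_size, head_size)
# layout is an assumption for this sketch, not taken from the test.
key_value_cache = torch.randn(num_seqs * blocks_per_seq, 2, num_kv_heads,
                              block_size, head_size,
                              dtype=dtype, device=device)
block_tables = torch.arange(num_seqs * blocks_per_seq, dtype=torch.int32,
                            device=device).reshape(num_seqs, blocks_per_seq)
kv_lens = [block_size * blocks_per_seq] * num_seqs
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
max_kv_len = max(kv_lens)
scale = head_size**-0.5
k_scale = v_scale = 1.0  # unit dequant scales, purely illustrative

# Caller-owned workspace allocated up front; never pass the kernel None.
workspace_buffer = torch.zeros(128 * 1024 * 1024,
                               dtype=torch.uint8, device=device)

# New pattern: preallocate the output and pass it through out= instead of
# relying on the kernel's return value.
output_trtllm = torch.empty(query.shape, dtype=dtype, device=device)
flashinfer.decode.trtllm_batch_decode_with_kv_cache(
    query.contiguous(),
    key_value_cache,
    workspace_buffer,
    block_tables,
    kv_lens_tensor,
    max_kv_len,
    bmm1_scale=k_scale * scale,  # softmax scale folded into the QK^T GEMM
    bmm2_scale=v_scale,          # V dequant scale applied in the PV GEMM
    out=output_trtllm,
)

Writing into a caller-owned out buffer keeps allocation under the caller's control, which matches the fix's theme: the workspace and output buffers are created explicitly up front rather than left for the kernel to materialize.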