[Bugfix] Fix workspace buffer None issue for Flashinfer TRTLLM Backend (#21525)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Author: elvischenv
Date: 2025-07-29 22:34:00 +08:00
Committed by: GitHub
Parent: ad341c5194
Commit: 58b11b24a6
4 changed files with 60 additions and 41 deletions
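The underlying problem: FlashInferMetadata carried its own workspace_buffer field, and on some metadata-construction paths that field could evidently end up as None, breaking the TRT-LLM decode call. The fix removes the field and reuses the float workspace that the planned decode wrapper already owns. A minimal sketch of that ownership, assuming flashinfer's public BatchDecodeWithPagedKVCacheWrapper (its _float_workspace_buffer attribute is the same one the diff reads):

import torch
import flashinfer

# The wrapper is constructed around a caller-provided workspace tensor, so its
# _float_workspace_buffer is a live CUDA tensor for as long as the wrapper
# exists, rather than an Optional field that a build path may leave unset.
workspace = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "HND")
buf = wrapper._float_workspace_buffer  # same storage as `workspace`; never None
assert buf.is_contiguous()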


@@ -194,7 +194,6 @@ class FlashInferMetadata:
     max_seq_len: int
     seq_lens: torch.Tensor
     block_table_tensor: torch.Tensor
-    workspace_buffer: torch.Tensor
 
     # For handling prefill decode split
     num_decodes: int
@@ -473,7 +472,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,
-            workspace_buffer=self._get_workspace_buffer(),
         )
 
         self._plan(num_prefills, num_decodes, attn_metadata)
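With these two hunks applied, the metadata no longer owns any workspace memory; it is reduced to shapes and indexing tensors. Roughly, as a trimmed stand-in (hypothetical name, not the full vLLM class):

from dataclasses import dataclass
import torch

@dataclass
class DecodeMetaSketch:  # hypothetical stand-in for FlashInferMetadata
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table_tensor: torch.Tensor
    num_decodes: int
    # no workspace_buffer field: the decode wrapper owns that memory now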
@@ -641,11 +639,11 @@ class FlashInferImpl(AttentionImpl):
         if decode_wrapper := attn_metadata.decode_wrapper:
             decode_query = query[:num_decode_tokens]
             assert decode_query.shape[0] == num_decode_tokens
-            assert decode_wrapper is not None
             if not FlashInferBackend.use_trtllm_decode_attention(
                     attn_metadata.num_decodes, attn_metadata.max_seq_len,
                     self.kv_cache_dtype, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
+                assert decode_wrapper is not None
                 assert decode_wrapper._window_left == window_left
                 assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                            or 0.0)
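The assert moves inside the non-TRT-LLM branch: the walrus in the enclosing if already guarantees a truthy wrapper, so the assert's remaining job is to narrow the Optional type right where the wrapper's private planning fields are dereferenced. Schematically (a control-flow sketch, with use_trtllm standing in for the full use_trtllm_decode_attention call):

if decode_wrapper := attn_metadata.decode_wrapper:   # filters out a None wrapper
    if not use_trtllm:
        assert decode_wrapper is not None            # type narrowing at point of use
        assert decode_wrapper._window_left == window_left
        ...                                          # FlashInfer's planned decode path
    else:
        ...                                          # TRT-LLM kernel path (next hunk)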
@@ -666,22 +664,24 @@ class FlashInferImpl(AttentionImpl):
                 block_tables_decode = attn_metadata.block_table_tensor[:
                                                                        num_decode_tokens]
                 seq_lens_decode = attn_metadata.seq_lens[:
                                                          num_decode_tokens]
+                workspace_buffer = decode_wrapper._float_workspace_buffer
                 assert get_kv_cache_layout() == "HND"
                 assert decode_query.is_contiguous()
                 assert kv_cache_permute.is_contiguous()
                 assert block_tables_decode.is_contiguous()
                 assert seq_lens_decode.is_contiguous()
+                assert workspace_buffer.is_contiguous()
 
-                output[:num_decode_tokens] = (
-                    trtllm_batch_decode_with_kv_cache(
-                        query=decode_query,
-                        kv_cache=kv_cache_permute,
-                        workspace_buffer=attn_metadata.workspace_buffer,
-                        block_tables=block_tables_decode,
-                        seq_lens=seq_lens_decode,
-                        max_seq_len=attn_metadata.max_seq_len,
-                        bmm1_scale=layer._k_scale_float * self.scale,
-                        bmm2_scale=layer._v_scale_float,
-                    ))
+                trtllm_batch_decode_with_kv_cache(
+                    query=decode_query,
+                    kv_cache=kv_cache_permute,
+                    workspace_buffer=workspace_buffer,
+                    block_tables=block_tables_decode,
+                    seq_lens=seq_lens_decode,
+                    max_seq_len=attn_metadata.max_seq_len,
+                    bmm1_scale=layer._k_scale_float * self.scale,
+                    bmm2_scale=layer._v_scale_float,
+                    out=output[:num_decode_tokens],
+                )
         return output_padded
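Two things change in the call itself: the workspace now comes from decode_wrapper._float_workspace_buffer instead of the possibly-None metadata field, and the kernel writes through out= instead of returning a tensor that is then assigned back into the output slice. A self-contained sketch of the new call shape, with illustrative sizes and an assumed flashinfer.decode import path; the keyword arguments mirror the diff above:

import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache  # assumed import path

num_tokens, num_qo_heads, num_kv_heads, head_dim = 4, 8, 1, 128
page_size, num_pages, max_seq_len = 16, 32, 256

query = torch.randn(num_tokens, num_qo_heads, head_dim,
                    dtype=torch.bfloat16, device="cuda")
# "HND" paged layout: [num_pages, 2 (K/V), num_kv_heads, page_size, head_dim]
kv_cache = torch.randn(num_pages, 2, num_kv_heads, page_size, head_dim,
                       dtype=torch.bfloat16, device="cuda")
workspace = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
block_tables = torch.zeros(num_tokens, max_seq_len // page_size,
                           dtype=torch.int32, device="cuda")
seq_lens = torch.full((num_tokens,), 32, dtype=torch.int32, device="cuda")
out = torch.empty_like(query)

trtllm_batch_decode_with_kv_cache(
    query=query,
    kv_cache=kv_cache,
    workspace_buffer=workspace,        # in vLLM: decode_wrapper._float_workspace_buffer
    block_tables=block_tables,
    seq_lens=seq_lens,
    max_seq_len=max_seq_len,
    bmm1_scale=1.0 / head_dim ** 0.5,  # in vLLM: layer._k_scale_float * self.scale
    bmm2_scale=1.0,                    # in vLLM: layer._v_scale_float
    out=out,                           # written in place; avoids the old copy-back assignment
)

Passing out= also keeps the destination slice (output[:num_decode_tokens] in the diff) owned by the caller, which matters when the surrounding buffer is preallocated for CUDA-graph replay.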