[Bugfix] Fix workspace buffer None issue for Flashinfer TRTLLM Backend (#21525)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
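
Judging by the diff below, the commit removes the workspace_buffer field from FlashInferMetadata (and the workspace_buffer=self._get_workspace_buffer() argument the builder passed), and instead reads the workspace tensor off the decode wrapper itself via decode_wrapper._float_workspace_buffer, so the TRTLLM decode call no longer receives a None buffer. A minimal sketch of that pattern, using a hypothetical FakeDecodeWrapper stand-in (not vLLM or flashinfer code):

    import torch

    class FakeDecodeWrapper:
        """Hypothetical stand-in for flashinfer's decode wrapper, which
        owns the workspace buffer it was constructed with."""
        def __init__(self, workspace_buffer: torch.Tensor):
            self._float_workspace_buffer = workspace_buffer

    wrapper = FakeDecodeWrapper(torch.empty(1024, dtype=torch.uint8))

    # Old shape: a separate metadata.workspace_buffer field could be left None.
    # New shape: the wrapper always carries the buffer it was built with.
    workspace_buffer = wrapper._float_workspace_buffer
    assert workspace_buffer is not None and workspace_buffer.is_contiguous()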
@@ -194,7 +194,6 @@ class FlashInferMetadata:
     max_seq_len: int
     seq_lens: torch.Tensor
     block_table_tensor: torch.Tensor
-    workspace_buffer: torch.Tensor
 
     # For handling prefill decode split
     num_decodes: int
@@ -473,7 +472,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,
-            workspace_buffer=self._get_workspace_buffer(),
         )
 
         self._plan(num_prefills, num_decodes, attn_metadata)
@@ -641,11 +639,11 @@ class FlashInferImpl(AttentionImpl):
         if decode_wrapper := attn_metadata.decode_wrapper:
             decode_query = query[:num_decode_tokens]
             assert decode_query.shape[0] == num_decode_tokens
+            assert decode_wrapper is not None
             if not FlashInferBackend.use_trtllm_decode_attention(
                     attn_metadata.num_decodes, attn_metadata.max_seq_len,
                     self.kv_cache_dtype, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
-                assert decode_wrapper is not None
                 assert decode_wrapper._window_left == window_left
                 assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                            or 0.0)
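
The hunk above hoists assert decode_wrapper is not None out of the non-TRTLLM branch so it guards both paths, including the _float_workspace_buffer access added in the next hunk. A hedged toy of why the placement matters (run_decode and use_trtllm are illustrative, not vLLM code): under if w := ..., w is already truthy in the body, so the assert mainly narrows the Optional type, and hoisting it does so for both branches.

    from typing import Optional

    def run_decode(wrapper: Optional[str], use_trtllm: bool) -> str:
        if w := wrapper:
            assert w is not None  # narrows Optional[str] for both branches
            if not use_trtllm:
                return "wrapper path: " + w
            return "trtllm path: " + w
        return "no decode tokens"

    print(run_decode("decode_wrapper", use_trtllm=True))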
@@ -666,22 +664,24 @@ class FlashInferImpl(AttentionImpl):
                                                           num_decode_tokens]
                 seq_lens_decode = attn_metadata.seq_lens[:
                                                          num_decode_tokens]
+                workspace_buffer = decode_wrapper._float_workspace_buffer
 
                 assert get_kv_cache_layout() == "HND"
                 assert decode_query.is_contiguous()
                 assert kv_cache_permute.is_contiguous()
                 assert block_tables_decode.is_contiguous()
                 assert seq_lens_decode.is_contiguous()
+                assert workspace_buffer.is_contiguous()
 
-                output[:num_decode_tokens] = (
-                    trtllm_batch_decode_with_kv_cache(
-                        query=decode_query,
-                        kv_cache=kv_cache_permute,
-                        workspace_buffer=attn_metadata.workspace_buffer,
-                        block_tables=block_tables_decode,
-                        seq_lens=seq_lens_decode,
-                        max_seq_len=attn_metadata.max_seq_len,
-                        bmm1_scale=layer._k_scale_float * self.scale,
-                        bmm2_scale=layer._v_scale_float,
-                    ))
+                trtllm_batch_decode_with_kv_cache(
+                    query=decode_query,
+                    kv_cache=kv_cache_permute,
+                    workspace_buffer=workspace_buffer,
+                    block_tables=block_tables_decode,
+                    seq_lens=seq_lens_decode,
+                    max_seq_len=attn_metadata.max_seq_len,
+                    bmm1_scale=layer._k_scale_float * self.scale,
+                    bmm2_scale=layer._v_scale_float,
+                    out=output[:num_decode_tokens],
+                )
         return output_padded
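
A hedged toy of the call-shape change in the final hunk, with torch.add standing in for trtllm_batch_decode_with_kv_cache (whose out= parameter the diff itself shows): the old code assigned the kernel's return value into the output slice, while the new code hands the kernel its destination directly.

    import torch

    output = torch.zeros(8, 4)      # padded output, decode tokens first
    decode_part = torch.ones(3, 4)  # toy stand-in for the decode inputs

    # Old shape: compute a fresh tensor, then copy it into the slice.
    output[:3] = torch.add(decode_part, 1)

    # New shape: the op writes into the preallocated slice in place via out=.
    torch.add(decode_part, 1, out=output[:3])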