[Misc] Simplify FlashInfer attention metadata (#23585)
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
This commit is contained in:
@@ -123,29 +123,9 @@ class FlashInferMetadata:
|
|||||||
|
|
||||||
num_actual_tokens: int # Number of tokens excluding padding.
|
num_actual_tokens: int # Number of tokens excluding padding.
|
||||||
|
|
||||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
|
||||||
# the batch, used to index into subquery. E.g., if the subquery length
|
|
||||||
# is [4, 6], it is [0, 4, 10].
|
|
||||||
qo_indptr_cpu: torch.Tensor
|
|
||||||
# An example for paged_kv_indices, paged_kv_indptr:
|
|
||||||
# request 1, page indices [0, 5, 8]
|
|
||||||
# request 2, page indices [1, 6, 7]
|
|
||||||
# request 3, page indices [3, 4]
|
|
||||||
# paged_kv_indices is a concatenation of page indices of all requests:
|
|
||||||
# [0, 5, 8, 1, 6, 7, 3, 4]
|
|
||||||
# paged_kv_indptr is used to index into paged_kv_indices:
|
|
||||||
# [0, 3, 6, 8]
|
|
||||||
# The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan)
|
|
||||||
paged_kv_indptr_cpu: torch.Tensor
|
|
||||||
# The page indices of the paged kv cache (on device for plan)
|
|
||||||
paged_kv_indices: torch.Tensor
|
|
||||||
# The number of entries in the last page of each request in
|
|
||||||
# the paged kv cache, shape: [batch_size] (CPU for plan)
|
|
||||||
paged_kv_last_page_len_cpu: torch.Tensor
|
|
||||||
# The data type of the query
|
# The data type of the query
|
||||||
q_data_type: torch.dtype
|
q_data_type: torch.dtype
|
||||||
|
|
||||||
seq_lens_cpu: torch.Tensor
|
|
||||||
slot_mapping: torch.Tensor
|
slot_mapping: torch.Tensor
|
||||||
|
|
||||||
# For flashinfer trtllm batch decode
|
# For flashinfer trtllm batch decode
|
||||||
@@ -164,10 +144,6 @@ class FlashInferMetadata:
|
|||||||
|
|
||||||
# For cascade attention (CPU for planning).
|
# For cascade attention (CPU for planning).
|
||||||
use_cascade: bool
|
use_cascade: bool
|
||||||
shared_qo_indptr_cpu: Optional[torch.Tensor] = None
|
|
||||||
shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None
|
|
||||||
shared_kv_page_indices_cpu: Optional[torch.Tensor] = None
|
|
||||||
shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
|
prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
|
||||||
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
|
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
|
||||||
@@ -327,134 +303,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
2, self._get_workspace_buffer(), get_kv_cache_layout())
|
2, self._get_workspace_buffer(), get_kv_cache_layout())
|
||||||
return self._cascade_wrapper
|
return self._cascade_wrapper
|
||||||
|
|
||||||
def _plan(self, attn_metadata: FlashInferMetadata):
|
|
||||||
if attn_metadata.use_cascade:
|
|
||||||
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
|
|
||||||
attn_metadata.cascade_wrapper.plan(
|
|
||||||
[
|
|
||||||
attn_metadata.shared_qo_indptr_cpu,
|
|
||||||
attn_metadata.qo_indptr_cpu
|
|
||||||
],
|
|
||||||
[
|
|
||||||
attn_metadata.shared_kv_page_indptr_cpu,
|
|
||||||
attn_metadata.paged_kv_indptr_cpu
|
|
||||||
],
|
|
||||||
[
|
|
||||||
attn_metadata.shared_kv_page_indices_cpu,
|
|
||||||
attn_metadata.paged_kv_indices
|
|
||||||
],
|
|
||||||
[
|
|
||||||
attn_metadata.shared_kv_last_page_len_cpu,
|
|
||||||
attn_metadata.paged_kv_last_page_len_cpu
|
|
||||||
],
|
|
||||||
self.num_qo_heads,
|
|
||||||
self.num_kv_heads,
|
|
||||||
self.head_dim,
|
|
||||||
self.page_size,
|
|
||||||
causal=True,
|
|
||||||
sm_scale=self.global_hyperparameters.sm_scale,
|
|
||||||
window_left=self.global_hyperparameters.window_left,
|
|
||||||
logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
|
|
||||||
q_data_type=self.q_data_type,
|
|
||||||
kv_data_type=self.kv_cache_dtype,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Regular attention (common case).
|
|
||||||
# Decodes are at the front and prefills are at the back,
|
|
||||||
# according to reorder_batch()
|
|
||||||
num_prefills = attn_metadata.num_prefills
|
|
||||||
num_decodes = attn_metadata.num_decodes
|
|
||||||
if num_prefills > 0:
|
|
||||||
# Decodes are first so prefills start after the last decode
|
|
||||||
prefill_start = num_decodes
|
|
||||||
attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
|
|
||||||
assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[
|
|
||||||
0] == num_prefills + 1
|
|
||||||
assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[
|
|
||||||
0] == num_prefills + 1
|
|
||||||
assert attn_metadata.paged_kv_last_page_len_cpu[
|
|
||||||
prefill_start:].shape[0] == num_prefills
|
|
||||||
# Since prefill_wrapper.run() will be called with
|
|
||||||
# query[num_decode_tokens:] we need to adjust the qo_indptr
|
|
||||||
# to be relative to the start of the prefill queries.
|
|
||||||
qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
|
|
||||||
prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
|
|
||||||
paged_kv_indptr_cpu = attn_metadata.paged_kv_indptr_cpu[
|
|
||||||
prefill_start:]
|
|
||||||
if not attn_metadata.prefill_use_trtllm:
|
|
||||||
attn_metadata.prefill_wrapper.plan(
|
|
||||||
qo_indptr_cpu,
|
|
||||||
paged_kv_indptr_cpu,
|
|
||||||
attn_metadata.paged_kv_indices,
|
|
||||||
attn_metadata.
|
|
||||||
paged_kv_last_page_len_cpu[prefill_start:],
|
|
||||||
self.num_qo_heads,
|
|
||||||
self.num_kv_heads,
|
|
||||||
self.head_dim,
|
|
||||||
self.page_size,
|
|
||||||
causal=True,
|
|
||||||
sm_scale=self.global_hyperparameters.sm_scale,
|
|
||||||
window_left=self.global_hyperparameters.window_left,
|
|
||||||
logits_soft_cap=self.global_hyperparameters.
|
|
||||||
logits_soft_cap,
|
|
||||||
q_data_type=self.q_data_type,
|
|
||||||
kv_data_type=self.kv_cache_dtype,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
|
|
||||||
attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
|
|
||||||
self.device)
|
|
||||||
|
|
||||||
if num_decodes > 0:
|
|
||||||
pure_decode = num_prefills == 0
|
|
||||||
# possible required padding for cudagraph replay
|
|
||||||
use_cudagraph = (self.enable_cuda_graph and pure_decode and
|
|
||||||
num_decodes <= self._decode_cudagraph_max_bs)
|
|
||||||
if use_cudagraph:
|
|
||||||
num_input_tokens = (
|
|
||||||
self.vllm_config.pad_for_cudagraph(num_decodes))
|
|
||||||
# Carefully fulfill the padding region with reasonable value
|
|
||||||
# on cpu.
|
|
||||||
# Make sure paged_kv_indptr_cpu is not decreasing
|
|
||||||
self.paged_kv_indptr_cpu[1 + num_decodes:1 +
|
|
||||||
num_input_tokens].fill_(
|
|
||||||
attn_metadata.
|
|
||||||
paged_kv_indptr_cpu[-1])
|
|
||||||
# Fill the remaining paged_kv_last_page_len_cpu with 1.
|
|
||||||
# This is because flashinfer treats 0 as a full page
|
|
||||||
# instead of empty.
|
|
||||||
self.paged_kv_last_page_len_cpu[
|
|
||||||
num_decodes:num_input_tokens].fill_(1)
|
|
||||||
|
|
||||||
else:
|
|
||||||
num_input_tokens = num_decodes
|
|
||||||
|
|
||||||
attn_metadata.decode_wrapper = self._get_decode_wrapper(
|
|
||||||
num_input_tokens, use_cudagraph)
|
|
||||||
if not attn_metadata.decode_use_trtllm:
|
|
||||||
# Use the persistent buffer with padding length,
|
|
||||||
# instead of the same address but chunked version
|
|
||||||
# in atten_metadata when using cudagraph.
|
|
||||||
fast_plan_decode(
|
|
||||||
attn_metadata.decode_wrapper,
|
|
||||||
self.paged_kv_indptr_cpu[:num_input_tokens + 1],
|
|
||||||
attn_metadata.paged_kv_indices,
|
|
||||||
self.paged_kv_last_page_len_cpu[:num_input_tokens],
|
|
||||||
attn_metadata.seq_lens_cpu[:num_input_tokens],
|
|
||||||
self.num_qo_heads,
|
|
||||||
self.num_kv_heads,
|
|
||||||
self.head_dim,
|
|
||||||
self.page_size,
|
|
||||||
# Disable flashinfer's pos encoding and use vllm's rope.
|
|
||||||
pos_encoding_mode="NONE",
|
|
||||||
sm_scale=self.global_hyperparameters.sm_scale,
|
|
||||||
window_left=self.global_hyperparameters.window_left,
|
|
||||||
logits_soft_cap=self.global_hyperparameters.
|
|
||||||
logits_soft_cap,
|
|
||||||
q_data_type=self.q_data_type,
|
|
||||||
kv_data_type=self.kv_cache_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
def build(self,
|
def build(self,
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
@@ -548,13 +396,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
|
|
||||||
attn_metadata = FlashInferMetadata(
|
attn_metadata = FlashInferMetadata(
|
||||||
num_actual_tokens=num_actual_tokens,
|
num_actual_tokens=num_actual_tokens,
|
||||||
qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
|
|
||||||
paged_kv_indptr_cpu=self.paged_kv_indptr_cpu[:1 + num_reqs],
|
|
||||||
paged_kv_indices=paged_kv_indices,
|
|
||||||
paged_kv_last_page_len_cpu=self.
|
|
||||||
paged_kv_last_page_len_cpu[:num_reqs],
|
|
||||||
q_data_type=self.q_data_type,
|
q_data_type=self.q_data_type,
|
||||||
seq_lens_cpu=seq_lens_cpu,
|
|
||||||
slot_mapping=common_attn_metadata.slot_mapping,
|
slot_mapping=common_attn_metadata.slot_mapping,
|
||||||
max_q_len=max_q_len,
|
max_q_len=max_q_len,
|
||||||
max_seq_len=max_seq_len,
|
max_seq_len=max_seq_len,
|
||||||
@@ -567,14 +409,123 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
num_prefills=num_prefills,
|
num_prefills=num_prefills,
|
||||||
num_prefill_tokens=num_prefill_tokens,
|
num_prefill_tokens=num_prefill_tokens,
|
||||||
use_cascade=use_cascade,
|
use_cascade=use_cascade,
|
||||||
shared_qo_indptr_cpu=shared_qo_indptr_cpu,
|
|
||||||
shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
|
|
||||||
shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
|
|
||||||
shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._plan(attn_metadata)
|
qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
|
||||||
|
paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_reqs]
|
||||||
|
paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]
|
||||||
|
|
||||||
|
if attn_metadata.use_cascade:
|
||||||
|
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
|
||||||
|
attn_metadata.cascade_wrapper.plan(
|
||||||
|
[shared_qo_indptr_cpu, qo_indptr_cpu],
|
||||||
|
[shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
|
||||||
|
[shared_kv_page_indices_cpu, paged_kv_indices],
|
||||||
|
[shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu],
|
||||||
|
self.num_qo_heads,
|
||||||
|
self.num_kv_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.page_size,
|
||||||
|
causal=True,
|
||||||
|
sm_scale=self.global_hyperparameters.sm_scale,
|
||||||
|
window_left=self.global_hyperparameters.window_left,
|
||||||
|
logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
|
||||||
|
q_data_type=self.q_data_type,
|
||||||
|
kv_data_type=self.kv_cache_dtype,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Regular attention (common case).
|
||||||
|
# Decodes are at the front and prefills are at the back,
|
||||||
|
# according to reorder_batch()
|
||||||
|
num_prefills = attn_metadata.num_prefills
|
||||||
|
num_decodes = attn_metadata.num_decodes
|
||||||
|
if num_prefills > 0:
|
||||||
|
# Decodes are first so prefills start after the last decode
|
||||||
|
prefill_start = num_decodes
|
||||||
|
attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
|
||||||
|
assert qo_indptr_cpu[prefill_start:].shape[
|
||||||
|
0] == num_prefills + 1
|
||||||
|
assert paged_kv_indptr_cpu[prefill_start:].shape[
|
||||||
|
0] == num_prefills + 1
|
||||||
|
assert paged_kv_last_page_len_cpu[prefill_start:].shape[
|
||||||
|
0] == num_prefills
|
||||||
|
# Since prefill_wrapper.run() will be called with
|
||||||
|
# query[num_decode_tokens:] we need to adjust the qo_indptr
|
||||||
|
# to be relative to the start of the prefill queries.
|
||||||
|
qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
|
||||||
|
prefill_start]
|
||||||
|
paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
|
||||||
|
if not attn_metadata.prefill_use_trtllm:
|
||||||
|
attn_metadata.prefill_wrapper.plan(
|
||||||
|
qo_indptr_cpu,
|
||||||
|
paged_kv_indptr_cpu,
|
||||||
|
paged_kv_indices,
|
||||||
|
paged_kv_last_page_len_cpu[prefill_start:],
|
||||||
|
self.num_qo_heads,
|
||||||
|
self.num_kv_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.page_size,
|
||||||
|
causal=True,
|
||||||
|
sm_scale=self.global_hyperparameters.sm_scale,
|
||||||
|
window_left=self.global_hyperparameters.window_left,
|
||||||
|
logits_soft_cap=self.global_hyperparameters.
|
||||||
|
logits_soft_cap,
|
||||||
|
q_data_type=self.q_data_type,
|
||||||
|
kv_data_type=self.kv_cache_dtype,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
|
||||||
|
attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
|
||||||
|
self.device)
|
||||||
|
|
||||||
|
if num_decodes > 0:
|
||||||
|
pure_decode = num_prefills == 0
|
||||||
|
# possible required padding for cudagraph replay
|
||||||
|
use_cudagraph = (self.enable_cuda_graph and pure_decode and
|
||||||
|
num_decodes <= self._decode_cudagraph_max_bs)
|
||||||
|
if use_cudagraph:
|
||||||
|
num_input_tokens = (
|
||||||
|
self.vllm_config.pad_for_cudagraph(num_decodes))
|
||||||
|
# Carefully fulfill the padding region with reasonable value
|
||||||
|
# on cpu.
|
||||||
|
# Make sure paged_kv_indptr_cpu is not decreasing
|
||||||
|
self.paged_kv_indptr_cpu[1 + num_decodes:1 +
|
||||||
|
num_input_tokens].fill_(
|
||||||
|
paged_kv_indptr_cpu[-1])
|
||||||
|
# Fill the remaining paged_kv_last_page_len_cpu with 1.
|
||||||
|
# This is because flashinfer treats 0 as a full page
|
||||||
|
# instead of empty.
|
||||||
|
self.paged_kv_last_page_len_cpu[
|
||||||
|
num_decodes:num_input_tokens].fill_(1)
|
||||||
|
|
||||||
|
else:
|
||||||
|
num_input_tokens = num_decodes
|
||||||
|
|
||||||
|
attn_metadata.decode_wrapper = self._get_decode_wrapper(
|
||||||
|
num_input_tokens, use_cudagraph)
|
||||||
|
if not attn_metadata.decode_use_trtllm:
|
||||||
|
# Use the persistent buffer with padding length,
|
||||||
|
# instead of the same address but chunked version
|
||||||
|
# in atten_metadata when using cudagraph.
|
||||||
|
fast_plan_decode(
|
||||||
|
attn_metadata.decode_wrapper,
|
||||||
|
self.paged_kv_indptr_cpu[:num_input_tokens + 1],
|
||||||
|
paged_kv_indices,
|
||||||
|
self.paged_kv_last_page_len_cpu[:num_input_tokens],
|
||||||
|
seq_lens_cpu[:num_input_tokens],
|
||||||
|
self.num_qo_heads,
|
||||||
|
self.num_kv_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.page_size,
|
||||||
|
# Disable flashinfer's pos encoding and use vllm's rope.
|
||||||
|
pos_encoding_mode="NONE",
|
||||||
|
sm_scale=self.global_hyperparameters.sm_scale,
|
||||||
|
window_left=self.global_hyperparameters.window_left,
|
||||||
|
logits_soft_cap=self.global_hyperparameters.
|
||||||
|
logits_soft_cap,
|
||||||
|
q_data_type=self.q_data_type,
|
||||||
|
kv_data_type=self.kv_cache_dtype,
|
||||||
|
)
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
def build_for_cudagraph_capture(
|
def build_for_cudagraph_capture(
|
||||||
|
|||||||
Reference in New Issue
Block a user