[NVIDIA] Support Flashinfer TRT-LLM Prefill Attention Kernel (#22095)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
@@ -12,6 +12,7 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
 from flashinfer.decode import (_get_range_buf, get_seq_lens,
                                trtllm_batch_decode_with_kv_cache)
+from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 
 import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -19,7 +20,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import cdiv, is_pin_memory_available
-from vllm.utils.flashinfer import use_trtllm_decode_attention
+from vllm.utils.flashinfer import use_trtllm_attention
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -149,9 +150,12 @@ class FlashInferMetadata:
     slot_mapping: torch.Tensor
 
     # For flashinfer trtllm batch decode
+    max_q_len: int
     max_seq_len: int
     seq_lens: torch.Tensor
     block_table_tensor: torch.Tensor
+    prefill_use_trtllm: bool
+    decode_use_trtllm: bool
 
     # For handling prefill decode split
     num_decodes: int
@@ -170,6 +174,9 @@ class FlashInferMetadata:
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
 
+    qo_indptr_gpu: Optional[torch.Tensor] = None
+    paged_kv_indptr_gpu: Optional[torch.Tensor] = None
+
     def __post_init__(self):
         if self.head_dim is not None:
             FlashInferBackend.validate_head_size(self.head_dim)
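
The two `_gpu` fields added above act as lazily populated device-side mirrors of the CPU planning buffers: they stay None on the regular FlashInfer wrapper path and are only filled in when the TRT-LLM prefill kernel will run. A minimal sketch of that pattern (hypothetical PrefillPlan class, not the real metadata):

    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class PrefillPlan:
        # Host-side planning buffer, always present.
        qo_indptr_cpu: torch.Tensor
        # Device mirror, populated only when the TRT-LLM path is chosen.
        qo_indptr_gpu: Optional[torch.Tensor] = None


    def plan(meta: PrefillPlan, use_trtllm: bool, device: str) -> None:
        if use_trtllm:
            meta.qo_indptr_gpu = meta.qo_indptr_cpu.to(device)


    meta = PrefillPlan(qo_indptr_cpu=torch.tensor([0, 4, 7], dtype=torch.int32))
    plan(meta, use_trtllm=False, device="cpu")
    assert meta.qo_indptr_gpu is None  # wrapper path: no device copy needed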
@@ -305,8 +312,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             2, self._get_workspace_buffer(), get_kv_cache_layout())
         return self._cascade_wrapper
 
-    def _plan(self, num_prefills: int, num_decodes: int,
-              attn_metadata: FlashInferMetadata):
+    def _plan(self, attn_metadata: FlashInferMetadata):
         if attn_metadata.use_cascade:
             attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
             attn_metadata.cascade_wrapper.plan(
@@ -341,6 +347,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             # Regular attention (common case).
             # Decodes are at the front and prefills are at the back,
             # according to reorder_batch()
+            num_prefills = attn_metadata.num_prefills
+            num_decodes = attn_metadata.num_decodes
             if num_prefills > 0:
                 # Decodes are first so prefills start after the last decode
                 prefill_start = num_decodes
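
These two locals come from the batch-reordering invariant the comments describe: decode requests sit at the front of the reordered batch and prefill requests at the back. A standalone sketch of that slicing convention (made-up shapes, not vLLM's actual tensors):

    import torch

    # Hypothetical reordered batch: 2 decodes (1 token each) followed by
    # 1 prefill with 5 query tokens, as reorder_batch() would arrange it.
    num_decode_tokens, num_prefill_tokens = 2, 5
    query = torch.randn(num_decode_tokens + num_prefill_tokens, 8, 64)

    decode_query = query[:num_decode_tokens]   # front of the batch
    prefill_query = query[num_decode_tokens:]  # back of the batch
    assert decode_query.shape[0] == num_decode_tokens
    assert prefill_query.shape[0] == num_prefill_tokens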
@@ -356,23 +364,31 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 # to be relative to the start of the prefill queries.
                 qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
                     prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
-                attn_metadata.prefill_wrapper.plan(
-                    qo_indptr_cpu,
-                    attn_metadata.paged_kv_indptr_cpu[prefill_start:],
-                    attn_metadata.paged_kv_indices,
-                    attn_metadata.paged_kv_last_page_len_cpu[prefill_start:],
-                    attn_metadata.num_qo_heads,
-                    attn_metadata.num_kv_heads,
-                    attn_metadata.head_dim,
-                    attn_metadata.page_size,
-                    causal=True,
-                    sm_scale=self.global_hyperparameters.sm_scale,
-                    window_left=self.global_hyperparameters.window_left,
-                    logits_soft_cap=self.global_hyperparameters.
-                    logits_soft_cap,
-                    q_data_type=attn_metadata.q_data_type,
-                    kv_data_type=attn_metadata.kv_data_type,
-                )
+                paged_kv_indptr_cpu = attn_metadata.paged_kv_indptr_cpu[
+                    prefill_start:]
+                if not attn_metadata.prefill_use_trtllm:
+                    attn_metadata.prefill_wrapper.plan(
+                        qo_indptr_cpu,
+                        paged_kv_indptr_cpu,
+                        attn_metadata.paged_kv_indices,
+                        attn_metadata.
+                        paged_kv_last_page_len_cpu[prefill_start:],
+                        attn_metadata.num_qo_heads,
+                        attn_metadata.num_kv_heads,
+                        attn_metadata.head_dim,
+                        attn_metadata.page_size,
+                        causal=True,
+                        sm_scale=self.global_hyperparameters.sm_scale,
+                        window_left=self.global_hyperparameters.window_left,
+                        logits_soft_cap=self.global_hyperparameters.
+                        logits_soft_cap,
+                        q_data_type=attn_metadata.q_data_type,
+                        kv_data_type=attn_metadata.kv_data_type,
+                    )
+                else:
+                    attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
+                    attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
+                        self.device)
 
             if num_decodes > 0:
                 pure_decode = num_prefills == 0
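
The subtraction above rebases the cumulative query-start offsets so the first prefill query starts at zero, and the else branch then ships the small planning buffers to the device once per build, since the TRT-LLM prefill kernel consumes device-side cum_seq_lens_q/cum_seq_lens_kv. A self-contained sketch of the rebasing (invented sizes):

    import torch

    # Cumulative query starts for 2 decodes (1 token each) plus 2 prefills
    # of 4 and 3 tokens: [0, 1, 2, 6, 9].
    qo_indptr_cpu = torch.tensor([0, 1, 2, 6, 9], dtype=torch.int32)
    prefill_start = 2  # decodes occupy the first two request slots

    # Rebase so the prefill offsets start at zero: [0, 4, 7].
    qo_indptr_prefill = (qo_indptr_cpu[prefill_start:] -
                         qo_indptr_cpu[prefill_start])
    assert qo_indptr_prefill.tolist() == [0, 4, 7]

    # One host-to-device copy per build for the TRT-LLM path (illustrative).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    qo_indptr_gpu = qo_indptr_prefill.to(device)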
@@ -400,11 +416,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
                 attn_metadata.decode_wrapper = self._get_decode_wrapper(
                     num_input_tokens, use_cudagraph)
-                if not use_trtllm_decode_attention(
-                        num_decodes, attn_metadata.max_seq_len,
-                        self.cache_config.cache_dtype,
-                        attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
-                        attn_metadata.head_dim):
+                if not attn_metadata.decode_use_trtllm:
                     # Use the persistent buffer with padding length,
                     # instead of the same address but chunked version
                     # in atten_metadata when using cudagraph.
@@ -437,6 +449,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             split_decodes_and_prefills(common_attn_metadata)
 
         page_size = self.kv_cache_spec.block_size
+        max_q_len = common_attn_metadata.max_query_len
         max_seq_len = common_attn_metadata.seq_lens_cpu.max()
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
@@ -503,6 +516,24 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 cache_dtype)
         else:
             kv_cache_dtype = self.kv_cache_spec.dtype
+
+        num_qo_heads = self.vllm_config.model_config.get_num_attention_heads(
+            self.vllm_config.parallel_config)
+        num_kv_heads = self.kv_cache_spec.num_kv_heads
+        head_dim = self.kv_cache_spec.head_size
+
+        # currently prefill trtllm attention does not support fp8 kv cache
+        # trtllm may not support sliding window
+        prefill_use_trtllm = (self.global_hyperparameters.window_left == -1
+                              and not cache_dtype.startswith("fp8")
+                              and use_trtllm_attention(
+                                  num_prefill_tokens, max_seq_len, cache_dtype,
+                                  num_qo_heads, num_kv_heads, head_dim))
+        decode_use_trtllm = (self.global_hyperparameters.window_left == -1
+                             and use_trtllm_attention(
+                                 num_decode_tokens, max_seq_len, cache_dtype,
+                                 num_qo_heads, num_kv_heads, head_dim))
+
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
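
The gating computed here is per phase: prefill additionally requires a non-fp8 KV cache, and both paths require no sliding window. A hedged sketch of the same decision shape (the real eligibility check is use_trtllm_attention; prefill_ok/decode_ok stand in for it here):

    def pick_trtllm_paths(window_left: int, cache_dtype: str,
                          prefill_ok: bool,
                          decode_ok: bool) -> tuple[bool, bool]:
        # window_left == -1 means no sliding window is in effect.
        no_sliding_window = window_left == -1
        prefill_use_trtllm = (no_sliding_window
                              and not cache_dtype.startswith("fp8")
                              and prefill_ok)
        decode_use_trtllm = no_sliding_window and decode_ok
        return prefill_use_trtllm, decode_use_trtllm


    # An fp8 KV cache rules out the TRT-LLM prefill path but not decode.
    assert pick_trtllm_paths(-1, "fp8_e4m3", True, True) == (False, True)
    # A sliding window rules out both.
    assert pick_trtllm_paths(1024, "auto", True, True) == (False, False)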
@@ -510,14 +541,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             paged_kv_indices=paged_kv_indices,
             paged_kv_last_page_len_cpu=self.
             paged_kv_last_page_len_cpu[:num_reqs],
-            num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
-                self.vllm_config.parallel_config),
-            num_kv_heads=self.kv_cache_spec.num_kv_heads,
-            head_dim=self.kv_cache_spec.head_size,
+            num_qo_heads=num_qo_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_dim,
             page_size=page_size,
             kv_data_type=kv_cache_dtype,
             q_data_type=self.vllm_config.model_config.dtype,
             slot_mapping=common_attn_metadata.slot_mapping,
+            max_q_len=max_q_len,
+            max_seq_len=max_seq_len,
+            seq_lens=seq_lens,
+            block_table_tensor=block_table_tensor,
+            prefill_use_trtllm=prefill_use_trtllm,
+            decode_use_trtllm=decode_use_trtllm,
             num_decodes=num_decodes,
             num_decode_tokens=num_decode_tokens,
             num_prefills=num_prefills,
@@ -527,12 +563,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
             shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
             shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
-            max_seq_len=max_seq_len,
-            seq_lens=seq_lens,
-            block_table_tensor=block_table_tensor,
         )
 
-        self._plan(num_prefills, num_decodes, attn_metadata)
+        self._plan(attn_metadata)
 
         return attn_metadata
@@ -698,30 +731,64 @@ class FlashInferImpl(AttentionImpl):
             # Regular attention (common case).
             # Decodes are at the front and prefills are at the back,
             # according to reorder_batch()
-            if prefill_wrapper := attn_metadata.prefill_wrapper:
+            if num_prefill_tokens > 0:
+                prefill_wrapper = attn_metadata.prefill_wrapper
                 prefill_query = query[num_decode_tokens:]
                 assert prefill_query.shape[0] == num_prefill_tokens
                 assert prefill_wrapper is not None
-                assert prefill_wrapper._causal
-                assert prefill_wrapper._window_left == window_left
-                assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap
-                                                            or 0.0)
-                assert prefill_wrapper._sm_scale == self.scale
-                prefill_wrapper.run(
-                    prefill_query,
-                    kv_cache_permute,
-                    k_scale=layer._k_scale_float,
-                    v_scale=layer._v_scale_float,
-                    out=output[num_decode_tokens:],
-                )
-            if decode_wrapper := attn_metadata.decode_wrapper:
+
+                if not attn_metadata.prefill_use_trtllm:
+                    assert prefill_wrapper._causal
+                    assert prefill_wrapper._window_left == window_left
+                    assert prefill_wrapper._logits_soft_cap == (
+                        self.logits_soft_cap or 0.0)
+                    assert prefill_wrapper._sm_scale == self.scale
+                    prefill_wrapper.run(
+                        prefill_query,
+                        kv_cache_permute,
+                        k_scale=layer._k_scale_float,
+                        v_scale=layer._v_scale_float,
+                        out=output[num_decode_tokens:],
+                    )
+                else:
+                    # prefill_query may be non-contiguous
+                    prefill_query = prefill_query.contiguous()
+                    workspace_buffer = prefill_wrapper._float_workspace_buffer
+                    block_tables_prefill = attn_metadata.block_table_tensor[
+                        num_decode_tokens:]
+                    seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
+
+                    # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
+                    assert get_kv_cache_layout() == "HND"
+                    assert prefill_query.is_contiguous()
+                    assert kv_cache_permute.is_contiguous()
+                    assert workspace_buffer.is_contiguous()
+                    assert block_tables_prefill.is_contiguous()
+                    assert seq_lens_prefill.is_contiguous()
+
+                    trtllm_batch_context_with_kv_cache(
+                        query=prefill_query,
+                        kv_cache=kv_cache_permute,
+                        workspace_buffer=workspace_buffer,
+                        block_tables=block_tables_prefill,
+                        seq_lens=seq_lens_prefill,
+                        max_q_len=attn_metadata.max_q_len,
+                        max_kv_len=attn_metadata.max_seq_len,
+                        bmm1_scale=layer._k_scale_float * self.scale,
+                        bmm2_scale=layer._v_scale_float,
+                        batch_size=attn_metadata.num_prefills,
+                        cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
+                        cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
+                        out=output[num_decode_tokens:],
+                    )
+
+            if num_decode_tokens > 0:
+                decode_wrapper = attn_metadata.decode_wrapper
                 decode_query = query[:num_decode_tokens]
                 assert decode_query.shape[0] == num_decode_tokens
                 assert decode_wrapper is not None
-                if not use_trtllm_decode_attention(
-                        attn_metadata.num_decodes, attn_metadata.max_seq_len,
-                        self.kv_cache_dtype, attn_metadata.num_qo_heads,
-                        attn_metadata.num_kv_heads, attn_metadata.head_dim):
+
+                if not attn_metadata.decode_use_trtllm:
                     assert decode_wrapper._window_left == window_left
                     assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                                or 0.0)
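
One detail worth calling out in the trtllm_batch_context_with_kv_cache call: the kernel takes pre-folded multipliers rather than separate scales, so the KV dequantization scale and the softmax scale are combined into bmm1_scale, while bmm2_scale carries the value-side scale alone. A small numeric check of that folding (illustrative values):

    import math

    head_dim = 128
    sm_scale = 1.0 / math.sqrt(head_dim)  # the usual softmax scaling
    k_scale, v_scale = 1.0, 1.0           # unquantized KV cache

    bmm1_scale = k_scale * sm_scale       # folded into the Q @ K^T matmul
    bmm2_scale = v_scale                  # applied after the P @ V matmul
    assert math.isclose(bmm1_scale, 1.0 / (8 * math.sqrt(2.0)))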
@@ -734,34 +801,32 @@ class FlashInferImpl(AttentionImpl):
                         out=output[:num_decode_tokens],
                     )
                 else:
-                    if num_decode_tokens > 0:
-                        # decode_query may be non-contiguous
-                        decode_query = decode_query.contiguous()
-                        block_tables_decode = attn_metadata.block_table_tensor[:
-                                                                               num_decode_tokens]
-                        seq_lens_decode = attn_metadata.seq_lens[:
-                                                                 num_decode_tokens]
-                        workspace_buffer = decode_wrapper._float_workspace_buffer
-
-                        assert get_kv_cache_layout() == "HND"
-                        assert decode_query.is_contiguous()
-                        assert kv_cache_permute.is_contiguous()
-                        assert block_tables_decode.is_contiguous()
-                        assert seq_lens_decode.is_contiguous()
-                        assert workspace_buffer.is_contiguous()
-
-                        trtllm_batch_decode_with_kv_cache(
-                            query=decode_query,
-                            kv_cache=kv_cache_permute,
-                            workspace_buffer=workspace_buffer,
-                            block_tables=block_tables_decode,
-                            seq_lens=seq_lens_decode,
-                            max_seq_len=attn_metadata.max_seq_len,
-                            bmm1_scale=layer._k_scale_float * self.scale,
-                            bmm2_scale=layer._v_scale_float,
-                            out=output[:num_decode_tokens],
-                        )
+                    # decode_query may be non-contiguous
+                    decode_query = decode_query.contiguous()
+                    workspace_buffer = decode_wrapper._float_workspace_buffer
+                    block_tables_decode = attn_metadata.block_table_tensor[:
+                                                                           num_decode_tokens]
+                    seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
+
+                    # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
+                    assert get_kv_cache_layout() == "HND"
+                    assert decode_query.is_contiguous()
+                    assert kv_cache_permute.is_contiguous()
+                    assert workspace_buffer.is_contiguous()
+                    assert block_tables_decode.is_contiguous()
+                    assert seq_lens_decode.is_contiguous()
+
+                    trtllm_batch_decode_with_kv_cache(
+                        query=decode_query,
+                        kv_cache=kv_cache_permute,
+                        workspace_buffer=workspace_buffer,
+                        block_tables=block_tables_decode,
+                        seq_lens=seq_lens_decode,
+                        max_seq_len=attn_metadata.max_seq_len,
+                        bmm1_scale=layer._k_scale_float * self.scale,
+                        bmm2_scale=layer._v_scale_float,
+                        out=output[:num_decode_tokens],
+                    )
         return output_padded
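
The repeated assert on get_kv_cache_layout() reflects a layout requirement of the TRT-LLM kernels: the paged KV cache must be stored head-major ("HND", heads before page slots) rather than token-major ("NHD"), which vLLM selects via VLLM_KV_CACHE_LAYOUT. A toy illustration of the difference and of why the contiguity asserts matter (invented shapes):

    import torch

    num_pages, page_size, num_kv_heads, head_dim = 4, 16, 2, 64

    # NHD: page slots before heads; HND: heads before page slots.
    kv_nhd = torch.zeros(num_pages, 2, page_size, num_kv_heads, head_dim)
    kv_hnd = kv_nhd.permute(0, 1, 3, 2, 4)
    assert kv_hnd.shape == (num_pages, 2, num_kv_heads, page_size, head_dim)

    # A bare permute is only a strided view; kernels that require contiguous
    # HND storage need an explicit materialization, which is exactly what
    # the is_contiguous() asserts above guard against.
    assert not kv_hnd.is_contiguous()
    kv_hnd = kv_hnd.contiguous()
    assert kv_hnd.is_contiguous()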
@@ -786,8 +851,8 @@ def fast_plan_decode(
     non_blocking: bool = True,
 ) -> None:
     """
     A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
     cudagraph capture/replay, while the no cudagraph version turns back
+    to the original plan.
     using original plan after passing host-side buffers:
     - only host-to-device copy of indptr and last_page_len buffers
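
The docstring's point, sketched with stand-in buffers: on cudagraph replay only the small host-side indptr and last-page-len buffers are copied into persistent device tensors, while the rest of the plan state is reused. A minimal sketch under those assumptions (not the wrapper's real bookkeeping):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pin = torch.cuda.is_available()  # pinned host memory only helps with a GPU

    # Persistent device-side buffers captured by the cudagraph.
    max_batch = 8
    indptr_gpu = torch.zeros(max_batch + 1, dtype=torch.int32, device=device)
    last_page_len_gpu = torch.zeros(max_batch, dtype=torch.int32, device=device)


    def fast_replan(indptr_cpu: torch.Tensor,
                    last_page_len_cpu: torch.Tensor) -> None:
        # Only these two host-to-device copies happen per replay.
        n = indptr_cpu.numel()
        indptr_gpu[:n].copy_(indptr_cpu, non_blocking=True)
        last_page_len_gpu[:n - 1].copy_(last_page_len_cpu, non_blocking=True)


    indptr_cpu = torch.tensor([0, 2, 5, 9], dtype=torch.int32, pin_memory=pin)
    last_cpu = torch.tensor([2, 3, 4], dtype=torch.int32, pin_memory=pin)
    fast_replan(indptr_cpu, last_cpu)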