[NVIDIA] Support Flashinfer TRT-LLM Prefill Attention Kernel (#22095)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
elvischenv
2025-08-05 17:45:34 +08:00
committed by GitHub
parent 4771df7b2b
commit 83156c7b89
9 changed files with 700 additions and 234 deletions
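In short: the FlashInfer v1 attention backend now computes per-batch prefill_use_trtllm and decode_use_trtllm flags once in FlashInferMetadataBuilder.build(), and FlashInferImpl.forward() dispatches on them, calling FlashInfer's trtllm_batch_context_with_kv_cache for prefill instead of the existing prefill wrapper when the flag is set. Below is a minimal, self-contained sketch of the gating conditions as they appear in the diff; the helper names wants_trtllm_prefill/wants_trtllm_decode are illustrative only, and the hardware/shape checks inside use_trtllm_attention (from vllm.utils.flashinfer) are collapsed into a single boolean here.

# Condensed sketch of the gating this commit adds in
# FlashInferMetadataBuilder.build(). The real support check
# (use_trtllm_attention) also looks at token counts, head counts,
# head size and max_seq_len; it is reduced to one flag here.
def wants_trtllm_prefill(window_left: int, cache_dtype: str,
                         supported_by_trtllm: bool) -> bool:
    # TRT-LLM prefill currently skips sliding-window attention
    # and fp8 KV caches, per the comments in the diff below.
    return (window_left == -1
            and not cache_dtype.startswith("fp8")
            and supported_by_trtllm)

def wants_trtllm_decode(window_left: int, supported_by_trtllm: bool) -> bool:
    # The decode path only excludes sliding-window attention.
    return window_left == -1 and supported_by_trtllm

# Example: global attention, non-fp8 KV cache, kernel support detected.
assert wants_trtllm_prefill(-1, "auto", True)
assert not wants_trtllm_prefill(-1, "fp8", True)
assert wants_trtllm_decode(-1, True)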


@@ -12,6 +12,7 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
MultiLevelCascadeAttentionWrapper)
from flashinfer.decode import (_get_range_buf, get_seq_lens,
trtllm_batch_decode_with_kv_cache)
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -19,7 +20,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import use_trtllm_decode_attention
from vllm.utils.flashinfer import use_trtllm_attention
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
# yapf conflicts with isort for this block
# yapf: disable
@@ -149,9 +150,12 @@ class FlashInferMetadata:
slot_mapping: torch.Tensor
# For flashinfer trtllm batch decode
max_q_len: int
max_seq_len: int
seq_lens: torch.Tensor
block_table_tensor: torch.Tensor
prefill_use_trtllm: bool
decode_use_trtllm: bool
# For handling prefill decode split
num_decodes: int
@@ -170,6 +174,9 @@ class FlashInferMetadata:
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
qo_indptr_gpu: Optional[torch.Tensor] = None
paged_kv_indptr_gpu: Optional[torch.Tensor] = None
def __post_init__(self):
if self.head_dim is not None:
FlashInferBackend.validate_head_size(self.head_dim)
@@ -305,8 +312,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
2, self._get_workspace_buffer(), get_kv_cache_layout())
return self._cascade_wrapper
def _plan(self, num_prefills: int, num_decodes: int,
attn_metadata: FlashInferMetadata):
def _plan(self, attn_metadata: FlashInferMetadata):
if attn_metadata.use_cascade:
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
attn_metadata.cascade_wrapper.plan(
@@ -341,6 +347,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
num_prefills = attn_metadata.num_prefills
num_decodes = attn_metadata.num_decodes
if num_prefills > 0:
# Decodes are first so prefills start after the last decode
prefill_start = num_decodes
@@ -356,23 +364,31 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# to be relative to the start of the prefill queries.
qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
attn_metadata.prefill_wrapper.plan(
qo_indptr_cpu,
attn_metadata.paged_kv_indptr_cpu[prefill_start:],
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_len_cpu[prefill_start:],
attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads,
attn_metadata.head_dim,
attn_metadata.page_size,
causal=True,
sm_scale=self.global_hyperparameters.sm_scale,
window_left=self.global_hyperparameters.window_left,
logits_soft_cap=self.global_hyperparameters.
logits_soft_cap,
q_data_type=attn_metadata.q_data_type,
kv_data_type=attn_metadata.kv_data_type,
)
paged_kv_indptr_cpu = attn_metadata.paged_kv_indptr_cpu[
prefill_start:]
if not attn_metadata.prefill_use_trtllm:
attn_metadata.prefill_wrapper.plan(
qo_indptr_cpu,
paged_kv_indptr_cpu,
attn_metadata.paged_kv_indices,
attn_metadata.
paged_kv_last_page_len_cpu[prefill_start:],
attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads,
attn_metadata.head_dim,
attn_metadata.page_size,
causal=True,
sm_scale=self.global_hyperparameters.sm_scale,
window_left=self.global_hyperparameters.window_left,
logits_soft_cap=self.global_hyperparameters.
logits_soft_cap,
q_data_type=attn_metadata.q_data_type,
kv_data_type=attn_metadata.kv_data_type,
)
else:
attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
self.device)
if num_decodes > 0:
pure_decode = num_prefills == 0
@@ -400,11 +416,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_metadata.decode_wrapper = self._get_decode_wrapper(
num_input_tokens, use_cudagraph)
if not use_trtllm_decode_attention(
num_decodes, attn_metadata.max_seq_len,
self.cache_config.cache_dtype,
attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
attn_metadata.head_dim):
if not attn_metadata.decode_use_trtllm:
# Use the persistent buffer with padding length,
# instead of the same address but chunked version
in attn_metadata when using cudagraph.
@@ -437,6 +449,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
split_decodes_and_prefills(common_attn_metadata)
page_size = self.kv_cache_spec.block_size
max_q_len = common_attn_metadata.max_query_len
max_seq_len = common_attn_metadata.seq_lens_cpu.max()
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
@@ -503,6 +516,24 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
cache_dtype)
else:
kv_cache_dtype = self.kv_cache_spec.dtype
num_qo_heads = self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config)
num_kv_heads = self.kv_cache_spec.num_kv_heads
head_dim = self.kv_cache_spec.head_size
# currently prefill trtllm attention does not support fp8 kv cache
# trtllm may not support sliding window
prefill_use_trtllm = (self.global_hyperparameters.window_left == -1
and not cache_dtype.startswith("fp8")
and use_trtllm_attention(
num_prefill_tokens, max_seq_len, cache_dtype,
num_qo_heads, num_kv_heads, head_dim))
decode_use_trtllm = (self.global_hyperparameters.window_left == -1
and use_trtllm_attention(
num_decode_tokens, max_seq_len, cache_dtype,
num_qo_heads, num_kv_heads, head_dim))
attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens,
qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
@@ -510,14 +541,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len_cpu=self.
paged_kv_last_page_len_cpu[:num_reqs],
num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config),
num_kv_heads=self.kv_cache_spec.num_kv_heads,
head_dim=self.kv_cache_spec.head_size,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
page_size=page_size,
kv_data_type=kv_cache_dtype,
q_data_type=self.vllm_config.model_config.dtype,
slot_mapping=common_attn_metadata.slot_mapping,
max_q_len=max_q_len,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table_tensor=block_table_tensor,
prefill_use_trtllm=prefill_use_trtllm,
decode_use_trtllm=decode_use_trtllm,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
@@ -527,12 +563,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table_tensor=block_table_tensor,
)
self._plan(num_prefills, num_decodes, attn_metadata)
self._plan(attn_metadata)
return attn_metadata
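The forward-pass hunk below passes attn_metadata.qo_indptr_gpu and attn_metadata.paged_kv_indptr_gpu to the TRT-LLM prefill kernel as cum_seq_lens_q and cum_seq_lens_kv. A small worked example of what those indptr tensors hold, using made-up request sizes and assuming the usual FlashInfer exclusive-prefix-sum convention:

import torch

# Three hypothetical prefill requests with 4, 2 and 7 new query tokens.
query_lens = torch.tensor([4, 2, 7], dtype=torch.int32)
qo_indptr = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(query_lens, dim=0)
# qo_indptr == tensor([0, 4, 6, 13]); request i's queries occupy
# rows qo_indptr[i]:qo_indptr[i + 1] of the packed query tensor.

# paged_kv_indptr is the same prefix sum over per-request KV page counts.
pages_per_req = torch.tensor([3, 1, 5], dtype=torch.int32)
paged_kv_indptr = torch.zeros(len(pages_per_req) + 1, dtype=torch.int32)
paged_kv_indptr[1:] = torch.cumsum(pages_per_req, dim=0)
# paged_kv_indptr == tensor([0, 3, 4, 9])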
@@ -698,30 +731,64 @@ class FlashInferImpl(AttentionImpl):
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
if prefill_wrapper := attn_metadata.prefill_wrapper:
if num_prefill_tokens > 0:
prefill_wrapper = attn_metadata.prefill_wrapper
prefill_query = query[num_decode_tokens:]
assert prefill_query.shape[0] == num_prefill_tokens
assert prefill_wrapper is not None
assert prefill_wrapper._causal
assert prefill_wrapper._window_left == window_left
assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap
or 0.0)
assert prefill_wrapper._sm_scale == self.scale
prefill_wrapper.run(
prefill_query,
kv_cache_permute,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=output[num_decode_tokens:],
)
if decode_wrapper := attn_metadata.decode_wrapper:
if not attn_metadata.prefill_use_trtllm:
assert prefill_wrapper._causal
assert prefill_wrapper._window_left == window_left
assert prefill_wrapper._logits_soft_cap == (
self.logits_soft_cap or 0.0)
assert prefill_wrapper._sm_scale == self.scale
prefill_wrapper.run(
prefill_query,
kv_cache_permute,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=output[num_decode_tokens:],
)
else:
# prefill_query may be non-contiguous
prefill_query = prefill_query.contiguous()
workspace_buffer = prefill_wrapper._float_workspace_buffer
block_tables_prefill = attn_metadata.block_table_tensor[
num_decode_tokens:]
seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND"
assert prefill_query.is_contiguous()
assert kv_cache_permute.is_contiguous()
assert workspace_buffer.is_contiguous()
assert block_tables_prefill.is_contiguous()
assert seq_lens_prefill.is_contiguous()
trtllm_batch_context_with_kv_cache(
query=prefill_query,
kv_cache=kv_cache_permute,
workspace_buffer=workspace_buffer,
block_tables=block_tables_prefill,
seq_lens=seq_lens_prefill,
max_q_len=attn_metadata.max_q_len,
max_kv_len=attn_metadata.max_seq_len,
bmm1_scale=layer._k_scale_float * self.scale,
bmm2_scale=layer._v_scale_float,
batch_size=attn_metadata.num_prefills,
cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
out=output[num_decode_tokens:],
)
if num_decode_tokens > 0:
decode_wrapper = attn_metadata.decode_wrapper
decode_query = query[:num_decode_tokens]
assert decode_query.shape[0] == num_decode_tokens
assert decode_wrapper is not None
if not use_trtllm_decode_attention(
attn_metadata.num_decodes, attn_metadata.max_seq_len,
self.kv_cache_dtype, attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads, attn_metadata.head_dim):
if not attn_metadata.decode_use_trtllm:
assert decode_wrapper._window_left == window_left
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
or 0.0)
@@ -734,34 +801,32 @@ class FlashInferImpl(AttentionImpl):
out=output[:num_decode_tokens],
)
else:
# decode_query may be non-contiguous
decode_query = decode_query.contiguous()
workspace_buffer = decode_wrapper._float_workspace_buffer
block_tables_decode = attn_metadata.block_table_tensor[:
num_decode_tokens]
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
if num_decode_tokens > 0:
# decode_query may be non-contiguous
decode_query = decode_query.contiguous()
block_tables_decode = attn_metadata.block_table_tensor[:
num_decode_tokens]
seq_lens_decode = attn_metadata.seq_lens[:
num_decode_tokens]
workspace_buffer = decode_wrapper._float_workspace_buffer
assert get_kv_cache_layout() == "HND"
assert decode_query.is_contiguous()
assert kv_cache_permute.is_contiguous()
assert workspace_buffer.is_contiguous()
assert block_tables_decode.is_contiguous()
assert seq_lens_decode.is_contiguous()
assert get_kv_cache_layout() == "HND"
assert decode_query.is_contiguous()
assert kv_cache_permute.is_contiguous()
assert block_tables_decode.is_contiguous()
assert seq_lens_decode.is_contiguous()
assert workspace_buffer.is_contiguous()
trtllm_batch_decode_with_kv_cache(
query=decode_query,
kv_cache=kv_cache_permute,
workspace_buffer=workspace_buffer,
block_tables=block_tables_decode,
seq_lens=seq_lens_decode,
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=layer._k_scale_float * self.scale,
bmm2_scale=layer._v_scale_float,
out=output[:num_decode_tokens],
)
trtllm_batch_decode_with_kv_cache(
query=decode_query,
kv_cache=kv_cache_permute,
workspace_buffer=workspace_buffer,
block_tables=block_tables_decode,
seq_lens=seq_lens_decode,
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=layer._k_scale_float * self.scale,
bmm2_scale=layer._v_scale_float,
out=output[:num_decode_tokens],
)
return output_padded
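One detail from the hunk above worth noting: the TRT-LLM kernels take fused scales rather than separate softmax and KV-cache scales. A short sketch of the arithmetic, assuming an illustrative head size of 128 and unit KV-cache scales:

import math

head_dim = 128                         # assumed for illustration
sm_scale = 1.0 / math.sqrt(head_dim)   # self.scale in FlashInferImpl
k_scale, v_scale = 1.0, 1.0            # layer._k_scale_float / _v_scale_float

bmm1_scale = k_scale * sm_scale        # folded into the Q.K^T GEMM
bmm2_scale = v_scale                   # applied to the P.V GEMM output
print(bmm1_scale, bmm2_scale)          # 0.08838834764831845 1.0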
@@ -786,8 +851,8 @@ def fast_plan_decode(
non_blocking: bool = True,
) -> None:
"""
A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
cudagraph capture/replay, while the no cudagraph version turns back
A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
cudagraph capture/replay, while the no cudagraph version turns back
to the original plan.
using original plan after passing host-side buffers:
- only host-to-device copy of indptr and last_page_len buffers