Optimize input preparation for FlashInfer [2/N] (#23174)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-08-27 02:52:45 -07:00
committed by GitHub
parent 5bd9f84158
commit 6578e87365

View File

@@ -6,6 +6,7 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar, Optional, Union from typing import ClassVar, Optional, Union
import numpy as np
import torch import torch
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper,
@@ -22,6 +23,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8StaticTensorSym, kNvfp4Quant) QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, is_pin_memory_available from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import (supports_trtllm_attention, from vllm.utils.flashinfer import (supports_trtllm_attention,
use_trtllm_attention) use_trtllm_attention)
@@ -230,6 +232,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=torch.int32, dtype=torch.int32,
device="cpu", device="cpu",
pin_memory=pin_memory) pin_memory=pin_memory)
self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
self.paged_kv_indices_cpu = torch.zeros(max_num_pages, self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
dtype=torch.int32, dtype=torch.int32,
device="cpu", device="cpu",
@@ -238,10 +241,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=torch.int32, dtype=torch.int32,
device="cpu", device="cpu",
pin_memory=pin_memory) pin_memory=pin_memory)
self.paged_kv_last_page_len_np = (
self.block_table_arange = torch.arange(max_num_pages_per_req, self.paged_kv_last_page_len_cpu.numpy())
dtype=torch.int32,
device=self.device)
def _get_workspace_buffer(self): def _get_workspace_buffer(self):
if self._workspace_buffer is None: if self._workspace_buffer is None:
@@ -317,9 +318,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
max_seq_len = common_attn_metadata.max_seq_len max_seq_len = common_attn_metadata.max_seq_len
seq_lens = common_attn_metadata.seq_lens seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_lens_np = seq_lens_cpu.numpy()
block_table_tensor = common_attn_metadata.block_table_tensor block_table_tensor = common_attn_metadata.block_table_tensor
block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size
use_cascade = common_prefix_len > 0 use_cascade = common_prefix_len > 0
if use_cascade: if use_cascade:
@@ -342,37 +344,41 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Remove the blocks of the shared prefix from all requests. # Remove the blocks of the shared prefix from all requests.
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
block_table_bounds_cpu -= num_common_kv_blocks num_blocks_np -= num_common_kv_blocks
else: else:
shared_qo_indptr_cpu = None shared_qo_indptr_cpu = None
shared_kv_page_indptr_cpu = None shared_kv_page_indptr_cpu = None
shared_kv_page_indices_cpu = None shared_kv_page_indices_cpu = None
shared_kv_last_page_len_cpu = None shared_kv_last_page_len_cpu = None
max_num_blocks = block_table_bounds_cpu.max().item()
block_table_bounds = block_table_bounds_cpu.to(self.device,
non_blocking=True)
mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
< block_table_bounds.unsqueeze(1))
# write self.paged_kv_indices inplace
num_actual_pages = torch.sum(mask)
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
torch.masked_select(block_table_tensor[:, :max_num_blocks],
mask,
out=paged_kv_indices)
# write self.paged_kv_indptr_cpu inplace (0-index is always 0) # write self.paged_kv_indptr_cpu inplace (0-index is always 0)
torch.cumsum(block_table_bounds_cpu, np.cumsum(
dim=0, num_blocks_np,
dtype=torch.int32, dtype=np.int32,
out=self.paged_kv_indptr_cpu[1:1 + num_reqs]) out=self.paged_kv_indptr_np[1:num_reqs + 1],
)
paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1]
paged_kv_indptr.copy_(self.paged_kv_indptr_cpu[:num_reqs + 1],
non_blocking=True)
# write self.paged_kv_indices inplace
num_actual_pages = num_blocks_np.sum().item()
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
_copy_page_indices_kernel[(num_reqs, )](
paged_kv_indices,
block_table_tensor,
block_table_tensor.stride(0),
paged_kv_indptr,
BLOCK_SIZE=1024,
)
paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
# write self.paged_kv_last_page_len_cpu inplace # write self.paged_kv_last_page_len_cpu inplace
torch.where(paged_kv_last_page_len_cpu == 0, paged_kv_last_page_len_np = seq_lens_np % page_size
torch.tensor(page_size), self.paged_kv_last_page_len_np[:num_reqs] = np.where(
paged_kv_last_page_len_cpu, paged_kv_last_page_len_np == 0,
out=self.paged_kv_last_page_len_cpu[:num_reqs]) page_size,
paged_kv_last_page_len_np,
)
# Check if any layer uses sinks (requires TRTLLM attention) # Check if any layer uses sinks (requires TRTLLM attention)
has_sinks = self.global_hyperparameters.has_sinks has_sinks = self.global_hyperparameters.has_sinks
@@ -1002,3 +1008,25 @@ def fast_plan_decode(
self._sm_scale = sm_scale self._sm_scale = sm_scale
self._rope_scale = rope_scale self._rope_scale = rope_scale
self._rope_theta = rope_theta self._rope_theta = rope_theta
@triton.jit
def _copy_page_indices_kernel(
    page_indices,
    block_table,
    block_table_stride,
    cu_num_blocks,
    BLOCK_SIZE: tl.constexpr,
):
    """Flatten per-request block-table rows into one contiguous index array.

    One program instance handles one request (grid dim 0 = request index):
    it copies that request's row of ``block_table`` into ``page_indices``
    at the offset given by the cumulative block counts ``cu_num_blocks``
    (an indptr-style array, so entry r..r+1 brackets request r's pages).
    """
    req_idx = tl.program_id(0)
    src_row = block_table + req_idx * block_table_stride
    dst_start = tl.load(cu_num_blocks + req_idx)
    dst_end = tl.load(cu_num_blocks + req_idx + 1)
    row_len = dst_end - dst_start

    lane = tl.arange(0, BLOCK_SIZE)
    # Stride through the row in BLOCK_SIZE chunks; the mask guards the
    # final partial chunk when row_len is not a multiple of BLOCK_SIZE.
    for base in tl.range(0, row_len, BLOCK_SIZE):
        idx = base + lane
        valid = idx < row_len
        vals = tl.load(src_row + idx, mask=valid)
        tl.store(page_indices + dst_start + idx, vals, mask=valid)