Optimize input preparation for FlashInfer [2/N] (#23174)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -6,6 +6,7 @@ from __future__ import annotations
 from dataclasses import dataclass
 from typing import ClassVar, Optional, Union

+import numpy as np
 import torch
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
@@ -22,6 +23,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, is_pin_memory_available
 from vllm.utils.flashinfer import (supports_trtllm_attention,
                                    use_trtllm_attention)
@@ -230,6 +232,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                dtype=torch.int32,
                                                device="cpu",
                                                pin_memory=pin_memory)
+        self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
         self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
                                                 dtype=torch.int32,
                                                 device="cpu",
@@ -238,10 +241,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                       dtype=torch.int32,
                                                       device="cpu",
                                                       pin_memory=pin_memory)
-        self.block_table_arange = torch.arange(max_num_pages_per_req,
-                                               dtype=torch.int32,
-                                               device=self.device)
+        self.paged_kv_last_page_len_np = (
+            self.paged_kv_last_page_len_cpu.numpy())

     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
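A minimal sketch (illustrative names, not taken from the diff) of why the new `.numpy()` views matter: a CPU tensor and its NumPy view share storage, so the in-place NumPy writes used below land directly in the pinned buffers that are later copied to the GPU.

import torch

# A CPU tensor and its .numpy() view share memory, so writing through the
# NumPy view fills the torch buffer without an extra copy.
indptr_cpu = torch.zeros(9, dtype=torch.int32)
indptr_np = indptr_cpu.numpy()       # zero-copy view of the same storage
indptr_np[1:4] = [2, 5, 9]           # write via NumPy
assert indptr_cpu[3].item() == 9     # visible through the torch tensor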
@@ -317,9 +318,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         max_seq_len = common_attn_metadata.max_seq_len
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+        seq_lens_np = seq_lens_cpu.numpy()
         block_table_tensor = common_attn_metadata.block_table_tensor

-        block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size
+        num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size

         use_cascade = common_prefix_len > 0
         if use_cascade:
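The per-request page count is now computed on the host with NumPy ceil division. A worked example with illustrative values (the page size and sequence lengths are made up):

import numpy as np

page_size = 16
seq_lens_np = np.array([5, 16, 33], dtype=np.int32)
num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size
print(num_blocks_np)  # [1 1 3] -> 1, 1, and 3 KV-cache pages per request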
@@ -342,37 +344,41 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

             # Remove the blocks of the shared prefix from all requests.
             block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
-            block_table_bounds_cpu -= num_common_kv_blocks
+            num_blocks_np -= num_common_kv_blocks
         else:
             shared_qo_indptr_cpu = None
             shared_kv_page_indptr_cpu = None
             shared_kv_page_indices_cpu = None
             shared_kv_last_page_len_cpu = None

-        max_num_blocks = block_table_bounds_cpu.max().item()
-        block_table_bounds = block_table_bounds_cpu.to(self.device,
-                                                       non_blocking=True)
-        mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
-                < block_table_bounds.unsqueeze(1))
-        # write self.paged_kv_indices inplace
-        num_actual_pages = torch.sum(mask)
-        paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
-        torch.masked_select(block_table_tensor[:, :max_num_blocks],
-                            mask,
-                            out=paged_kv_indices)
-
         # write self.paged_kv_indptr_cpu inplace (0-index is always 0)
-        torch.cumsum(block_table_bounds_cpu,
-                     dim=0,
-                     dtype=torch.int32,
-                     out=self.paged_kv_indptr_cpu[1:1 + num_reqs])
+        np.cumsum(
+            num_blocks_np,
+            dtype=np.int32,
+            out=self.paged_kv_indptr_np[1:num_reqs + 1],
+        )
+        paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1]
+        paged_kv_indptr.copy_(self.paged_kv_indptr_cpu[:num_reqs + 1],
+                              non_blocking=True)
+
+        # write self.paged_kv_indices inplace
+        num_actual_pages = num_blocks_np.sum().item()
+        paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
+        _copy_page_indices_kernel[(num_reqs, )](
+            paged_kv_indices,
+            block_table_tensor,
+            block_table_tensor.stride(0),
+            paged_kv_indptr,
+            BLOCK_SIZE=1024,
+        )

-        paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
         # write self.paged_kv_last_page_len_cpu inplace
-        torch.where(paged_kv_last_page_len_cpu == 0,
-                    torch.tensor(page_size),
-                    paged_kv_last_page_len_cpu,
-                    out=self.paged_kv_last_page_len_cpu[:num_reqs])
+        paged_kv_last_page_len_np = seq_lens_np % page_size
+        self.paged_kv_last_page_len_np[:num_reqs] = np.where(
+            paged_kv_last_page_len_np == 0,
+            page_size,
+            paged_kv_last_page_len_np,
+        )

         # Check if any layer uses sinks (requires TRTLLM attention)
         has_sinks = self.global_hyperparameters.has_sinks
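Taken together, the host side now derives paged_kv_indptr and paged_kv_last_page_len with plain NumPy arithmetic instead of per-build torch ops. A self-contained sketch of that arithmetic, continuing the illustrative values above (here the indptr buffer is freshly allocated rather than preallocated and pinned as in the builder):

import numpy as np

num_reqs, page_size = 3, 16
seq_lens_np = np.array([5, 16, 33], dtype=np.int32)
num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size  # [1, 1, 3]

# indptr[0] stays 0; the cumulative sum is written into the indptr slice,
# mirroring the out= argument used in the diff.
indptr_np = np.zeros(num_reqs + 1, dtype=np.int32)
np.cumsum(num_blocks_np, dtype=np.int32, out=indptr_np[1:num_reqs + 1])
# indptr_np == [0, 1, 2, 5]

# A sequence that exactly fills its last page maps to page_size, not 0.
last_page_len_np = seq_lens_np % page_size
last_page_len_np = np.where(last_page_len_np == 0, page_size, last_page_len_np)
# last_page_len_np == [5, 16, 1]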
@@ -1002,3 +1008,25 @@ def fast_plan_decode(
     self._sm_scale = sm_scale
     self._rope_scale = rope_scale
     self._rope_theta = rope_theta
+
+
+@triton.jit
+def _copy_page_indices_kernel(
+    page_indices,
+    block_table,
+    block_table_stride,
+    cu_num_blocks,
+    BLOCK_SIZE: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    row_ptr = block_table + req_idx * block_table_stride
+    start_idx = tl.load(cu_num_blocks + req_idx)
+    end_idx = tl.load(cu_num_blocks + req_idx + 1)
+    num_blocks = end_idx - start_idx
+
+    offset = tl.arange(0, BLOCK_SIZE)
+    for i in tl.range(0, num_blocks, BLOCK_SIZE):
+        block_ids = tl.load(row_ptr + i + offset, mask=i + offset < num_blocks)
+        tl.store(page_indices + start_idx + i + offset,
+                 block_ids,
+                 mask=i + offset < num_blocks)
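The Triton kernel packs, for each request, the first num_blocks entries of its block-table row into the flat paged_kv_indices array at the offset given by the indptr; launching with a (num_reqs,) grid gives one program per request, and the BLOCK_SIZE loop covers rows longer than 1024 blocks. A plain-NumPy model of the same computation, usable as a test oracle (the helper name and values are illustrative, not part of the commit):

import numpy as np

def copy_page_indices_ref(block_table: np.ndarray,
                          cu_num_blocks: np.ndarray) -> np.ndarray:
    # block_table: [num_reqs, max_blocks_per_req]; cu_num_blocks: [num_reqs + 1]
    out = np.empty(int(cu_num_blocks[-1]), dtype=block_table.dtype)
    for req_idx in range(block_table.shape[0]):
        start, end = cu_num_blocks[req_idx], cu_num_blocks[req_idx + 1]
        out[start:end] = block_table[req_idx, :end - start]
    return out

block_table = np.array([[7, 0, 0], [3, 0, 0], [4, 5, 6]], dtype=np.int32)
cu_num_blocks = np.array([0, 1, 2, 5], dtype=np.int32)
print(copy_page_indices_ref(block_table, cu_num_blocks))  # [7 3 4 5 6]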