[Kernel] Flash Attention 3 Support (#12093)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
@@ -199,11 +199,11 @@ class GPUModelRunner:
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
|
||||
self.seq_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
|
||||
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
||||
|
||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
# Remove stopped requests from the cached states.
|
||||
@@ -412,11 +412,10 @@ class GPUModelRunner:
|
||||
np.cumsum(num_scheduled_tokens,
|
||||
out=self.query_start_loc_np[1:num_reqs + 1])
|
||||
|
||||
seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||||
num_scheduled_tokens)
|
||||
max_seq_len = seq_lens.max()
|
||||
self.seq_start_loc_np[0] = 0
|
||||
np.cumsum(seq_lens, out=self.seq_start_loc_np[1:num_reqs + 1])
|
||||
self.seq_lens_np[:num_reqs] = (
|
||||
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||||
num_scheduled_tokens)
|
||||
max_seq_len = self.seq_lens_np[:num_reqs].max()
|
||||
|
||||
# Copy the tensors to the GPU.
|
||||
self.input_ids[:total_num_scheduled_tokens].copy_(
|
||||
@@ -433,8 +432,8 @@ class GPUModelRunner:
|
||||
non_blocking=True)
|
||||
query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
|
||||
self.device, non_blocking=True)
|
||||
seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to(
|
||||
self.device, non_blocking=True)
|
||||
seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
|
||||
non_blocking=True)
|
||||
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
|
||||
self.device, non_blocking=True).long()
|
||||
|
||||
@@ -506,33 +505,30 @@ class GPUModelRunner:
|
||||
[0, total_num_scheduled_tokens],
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
cu_prefix_kv_lens = torch.tensor([0, common_prefix_len],
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
cu_suffix_kv_lens = (
|
||||
self.seq_start_loc_np[:num_reqs + 1] -
|
||||
self.arange_np[:num_reqs + 1] * common_prefix_len)
|
||||
cu_suffix_kv_lens = torch.from_numpy(cu_suffix_kv_lens).to(
|
||||
self.device)
|
||||
prefix_kv_lens = torch.tensor([common_prefix_len],
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
suffix_kv_lens = (self.seq_lens_np[:num_reqs] - common_prefix_len)
|
||||
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(self.device)
|
||||
else:
|
||||
cu_prefix_query_lens = None
|
||||
cu_prefix_kv_lens = None
|
||||
cu_suffix_kv_lens = None
|
||||
prefix_kv_lens = None
|
||||
suffix_kv_lens = None
|
||||
|
||||
attn_metadata = FlashAttentionMetadata(
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
query_start_loc=query_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_start_loc=seq_start_loc,
|
||||
seq_lens=seq_lens,
|
||||
block_table=(
|
||||
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
|
||||
slot_mapping=slot_mapping,
|
||||
use_cascade=use_cascade,
|
||||
common_prefix_len=common_prefix_len,
|
||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||
cu_prefix_kv_lens=cu_prefix_kv_lens,
|
||||
cu_suffix_kv_lens=cu_suffix_kv_lens,
|
||||
prefix_kv_lens=prefix_kv_lens,
|
||||
suffix_kv_lens=suffix_kv_lens,
|
||||
)
|
||||
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
|
||||
# request in the batch. While we should not sample any token from this
|
||||
|
||||
Reference in New Issue
Block a user