[Kernel] Flash Attention 3 Support (#12093)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Author: Lucas Wilkinson
Date:   2025-01-23 09:45:48 -05:00
Committed by: GitHub
Parent: c5b4b11d7f
Commit: 978b45f399

8 changed files with 151 additions and 83 deletions

vllm/v1/worker/gpu_model_runner.py

@@ -199,11 +199,11 @@ class GPUModelRunner:
                                                device="cpu",
                                                pin_memory=self.pin_memory)
         self.query_start_loc_np = self.query_start_loc_cpu.numpy()
-        self.seq_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
-                                             dtype=torch.int32,
-                                             device="cpu",
-                                             pin_memory=self.pin_memory)
-        self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
+        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
+                                        dtype=torch.int32,
+                                        device="cpu",
+                                        pin_memory=self.pin_memory)
+        self.seq_lens_np = self.seq_lens_cpu.numpy()
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # Remove stopped requests from the cached states.
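
The hunk above swaps the cumulative `seq_start_loc` buffer (`max_num_reqs + 1` entries) for a per-request `seq_lens` buffer (`max_num_reqs` entries), keeping the same pinned-CPU-tensor-plus-NumPy-view pattern. A minimal standalone sketch of that pattern (not the vLLM code itself; names and sizes are illustrative):

```python
import torch

max_num_reqs = 8
pin = torch.cuda.is_available()  # pinning requires a CUDA runtime
seq_lens_cpu = torch.zeros(max_num_reqs, dtype=torch.int32,
                           device="cpu", pin_memory=pin)
seq_lens_np = seq_lens_cpu.numpy()  # zero-copy view of the same storage

seq_lens_np[:4] = [5, 7, 3, 9]      # NumPy writes are visible to the tensor
assert seq_lens_cpu[1].item() == 7
```

Because `.numpy()` shares storage with the CPU tensor, filling the NumPy view each step avoids per-step allocations, and the pinned buffer keeps the later `.to(device, non_blocking=True)` copy genuinely asynchronous.
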
@@ -412,11 +412,10 @@ class GPUModelRunner:
         self.query_start_loc_np[0] = 0
         np.cumsum(num_scheduled_tokens,
                   out=self.query_start_loc_np[1:num_reqs + 1])
-        seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] +
-                    num_scheduled_tokens)
-        max_seq_len = seq_lens.max()
-        self.seq_start_loc_np[0] = 0
-        np.cumsum(seq_lens, out=self.seq_start_loc_np[1:num_reqs + 1])
+        self.seq_lens_np[:num_reqs] = (
+            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
+            num_scheduled_tokens)
+        max_seq_len = self.seq_lens_np[:num_reqs].max()
 
         # Copy the tensors to the GPU.
         self.input_ids[:total_num_scheduled_tokens].copy_(
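
Each request's sequence length is simply the tokens already in its KV cache plus the tokens scheduled this step; the cumsum that built `seq_start_loc` is dropped, presumably because the FA3 path consumes per-request KV lengths rather than cumulative offsets. A small worked sketch with made-up numbers (illustrative, not vLLM code):

```python
import numpy as np

num_computed_tokens = np.array([10, 4, 0], dtype=np.int32)   # already cached
num_scheduled_tokens = np.array([1, 1, 8], dtype=np.int32)   # this step

seq_lens = num_computed_tokens + num_scheduled_tokens        # [11, 5, 8]
max_seq_len = seq_lens.max()                                 # 11

# The cumulative form the deleted code built is just a cumsum of the above:
seq_start_loc = np.zeros(len(seq_lens) + 1, dtype=np.int32)
np.cumsum(seq_lens, out=seq_start_loc[1:])                   # [0, 11, 16, 24]
```

Writing the result directly into `self.seq_lens_np[:num_reqs]` (as the new code does) also reuses the preallocated pinned buffer instead of allocating a fresh array each step.
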
@@ -433,8 +432,8 @@ class GPUModelRunner:
                               non_blocking=True)
         query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
             self.device, non_blocking=True)
-        seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to(
-            self.device, non_blocking=True)
+        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
+                                                   non_blocking=True)
         slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
             self.device, non_blocking=True).long()
 
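
The device copies keep the established pattern: slice the pinned CPU buffer, then `.to(self.device, non_blocking=True)`. A self-contained sketch of why the pinning matters for the non-blocking copy (illustrative only):

```python
import torch

if torch.cuda.is_available():
    num_reqs = 8
    seq_lens_cpu = torch.zeros(num_reqs, dtype=torch.int32,
                               device="cpu", pin_memory=True)
    seq_lens_cpu[:] = torch.arange(1, num_reqs + 1, dtype=torch.int32)

    # From pinned memory, non_blocking=True returns immediately and the
    # host-to-device copy can overlap with other work; from pageable
    # memory the copy would silently synchronize.
    seq_lens = seq_lens_cpu.to("cuda", non_blocking=True)

    # Kernels enqueued on the same stream are ordered after the copy; an
    # explicit sync is only needed before reusing the host buffer.
    torch.cuda.synchronize()
```
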
@@ -506,33 +505,30 @@ class GPUModelRunner:
                 [0, total_num_scheduled_tokens],
                 dtype=torch.int32,
                 device=self.device)
-            cu_prefix_kv_lens = torch.tensor([0, common_prefix_len],
-                                             dtype=torch.int32,
-                                             device=self.device)
-            cu_suffix_kv_lens = (
-                self.seq_start_loc_np[:num_reqs + 1] -
-                self.arange_np[:num_reqs + 1] * common_prefix_len)
-            cu_suffix_kv_lens = torch.from_numpy(cu_suffix_kv_lens).to(
-                self.device)
+            prefix_kv_lens = torch.tensor([common_prefix_len],
+                                          dtype=torch.int32,
+                                          device=self.device)
+            suffix_kv_lens = (self.seq_lens_np[:num_reqs] - common_prefix_len)
+            suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(self.device)
         else:
             cu_prefix_query_lens = None
-            cu_prefix_kv_lens = None
-            cu_suffix_kv_lens = None
+            prefix_kv_lens = None
+            suffix_kv_lens = None
 
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=total_num_scheduled_tokens,
             max_query_len=max_num_scheduled_tokens,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
-            seq_start_loc=seq_start_loc,
+            seq_lens=seq_lens,
             block_table=(
                 self.input_batch.block_table.get_device_tensor()[:num_reqs]),
             slot_mapping=slot_mapping,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
             cu_prefix_query_lens=cu_prefix_query_lens,
-            cu_prefix_kv_lens=cu_prefix_kv_lens,
-            cu_suffix_kv_lens=cu_suffix_kv_lens,
+            prefix_kv_lens=prefix_kv_lens,
+            suffix_kv_lens=suffix_kv_lens,
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
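
The cascade-attention metadata changes the same way: the cumulative `cu_prefix_kv_lens`/`cu_suffix_kv_lens` offsets become per-entry lengths, which matches a kernel interface that takes per-request KV lengths directly rather than cumulative start offsets (an assumption here; the exact FA3 signature is not shown in this hunk). A sketch of the length split with made-up numbers, not vLLM code:

```python
import numpy as np

common_prefix_len = 16
seq_lens_np = np.array([20, 33, 16], dtype=np.int32)  # per-request totals

# The shared prefix is attended once, as a single "request" of that length;
# each real request then attends only its own suffix.
prefix_kv_lens = np.array([common_prefix_len], dtype=np.int32)  # [16]
suffix_kv_lens = seq_lens_np - common_prefix_len                # [4, 17, 0]

# Prefix plus suffix reconstructs every request's full KV length.
assert (prefix_kv_lens[0] + suffix_kv_lens == seq_lens_np).all()
```
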