[ROCm][Attention] Sliding window support for AiterFlashAttentionBackend (#29234)
Signed-off-by: ganyi <ygan@amd.com>
@@ -13,8 +13,9 @@ from vllm.attention.backends.abstract import (
     AttentionType,
     MultipleOf,
 )
+from vllm.attention.layer import Attention
 from vllm.attention.ops.merge_attn_states import merge_attn_states
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils.math_utils import cdiv
@@ -57,58 +58,55 @@ if current_platform.is_rocm():
         head_size,
         x,
         max_block_num,
-        num_tokens,
-        num_programs,
         DEQUANT: tl.constexpr,
         PAGE_SIZE: tl.constexpr,
         CACHE_FORMAT: tl.constexpr,
         BLOCK_SIZE: tl.constexpr,
     ):
-        bid = tl.program_id(0)
+        token_id = tl.program_id(0)
         col_offsets = tl.arange(0, BLOCK_SIZE)
         if DEQUANT:
             k_scale = tl.load(k_scale_ptr)
             v_scale = tl.load(v_scale_ptr)
 
-        for token_id in tl.range(bid, num_tokens, num_programs):
-            key_ptr_offset = key_ptr + token_id * head_size * num_heads
-            value_ptr_offset = value_ptr + token_id * head_size * num_heads
-            batch_idx = tl.load(token_to_batch_ptr + token_id)
-            batch_start = tl.load(seq_start_ptr + batch_idx)
-            token_start = tl.load(cu_seqlens_kv_ptr + batch_idx)
-            batch_offset = token_id - token_start + batch_start
-            block_offset = batch_offset // PAGE_SIZE
-            block_id = tl.load(
-                block_table_ptr + max_block_num * batch_idx + block_offset
-            ).to(tl.int64)
-            slot_id = batch_offset % PAGE_SIZE
-
-            if CACHE_FORMAT == "NHD":
-                # for kv cache layout as
-                # K: [num_blocks, page_size, num_head, head_dim]
-                # V: [num_blocks, page_size, num_head, head_dim]
-                key_cache_ptr_offset = (
-                    key_cache_ptr
-                    + block_id * num_heads * head_size * PAGE_SIZE
-                    + slot_id * num_heads * head_size
-                )
-                value_cache_ptr_offset = (
-                    value_cache_ptr
-                    + block_id * num_heads * head_size * PAGE_SIZE
-                    + slot_id * num_heads * head_size
-                )
-
-            for i in tl.range(0, head_size * num_heads, BLOCK_SIZE):
-                mask = (col_offsets + i) < head_size * num_heads
-                k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask)
-                v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask)
-                if DEQUANT:
-                    k_dtype = k_reg.dtype
-                    v_dtype = v_reg.dtype
-                    k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
-                    v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
-                tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask)
-                tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask)
+        key_ptr_offset = key_ptr + token_id * head_size * num_heads
+        value_ptr_offset = value_ptr + token_id * head_size * num_heads
+        batch_idx = tl.load(token_to_batch_ptr + token_id)
+        batch_start = tl.load(seq_start_ptr + batch_idx)
+        token_start = tl.load(cu_seqlens_kv_ptr + batch_idx)
+        batch_offset = token_id - token_start + batch_start
+        block_offset = batch_offset // PAGE_SIZE
+        block_id = tl.load(
+            block_table_ptr + max_block_num * batch_idx + block_offset
+        ).to(tl.int64)
+        slot_id = batch_offset % PAGE_SIZE
+
+        if CACHE_FORMAT == "NHD":
+            # for kv cache layout as
+            # K: [num_blocks, page_size, num_head, head_dim]
+            # V: [num_blocks, page_size, num_head, head_dim]
+            key_cache_ptr_offset = (
+                key_cache_ptr
+                + block_id * num_heads * head_size * PAGE_SIZE
+                + slot_id * num_heads * head_size
+            )
+            value_cache_ptr_offset = (
+                value_cache_ptr
+                + block_id * num_heads * head_size * PAGE_SIZE
+                + slot_id * num_heads * head_size
+            )
+
+        for i in tl.range(0, head_size * num_heads, BLOCK_SIZE):
+            mask = (col_offsets + i) < head_size * num_heads
+            k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask)
+            v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask)
+            if DEQUANT:
+                k_dtype = k_reg.dtype
+                v_dtype = v_reg.dtype
+                k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
+                v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
+            tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask)
+            tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask)
 
     def cp_mha_gather_cache(
         key_cache: torch.Tensor,
@@ -143,9 +141,7 @@ if current_platform.is_rocm():
         page_size = key_cache.shape[1]
         num_heads = key_cache.shape[2]
 
-        NUM_PRGMS = num_programs(total_tokens)
-        BLOCK_SIZE = block_size(key_cache, head_dim)
-        grid = lambda meta: (NUM_PRGMS,)
+        grid = lambda meta: (total_tokens,)
         cp_mha_gather_cache_kernel[grid](
             key_cache,
             value_cache,
@@ -161,12 +157,10 @@ if current_platform.is_rocm():
             head_dim,
             x,
             block_tables.size(1),
-            total_tokens,
-            NUM_PRGMS,
             DEQUANT=dequant,
             PAGE_SIZE=page_size,
             CACHE_FORMAT=kv_cache_layout,
-            BLOCK_SIZE=BLOCK_SIZE,
+            BLOCK_SIZE=head_dim,
         )
 
 
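
For reference, the gather performed by cp_mha_gather_cache_kernel can be written as a plain PyTorch loop: each gathered token looks up its page through the block table and copies one [num_heads, head_dim] row of K and V into a contiguous output. This is a minimal sketch for the NHD cache layout only; the function and argument names below are illustrative and not part of the kernel's interface.

import torch

def gather_cache_reference(
    key_cache: torch.Tensor,      # [num_blocks, page_size, num_heads, head_dim]
    value_cache: torch.Tensor,    # same layout as key_cache
    block_table: torch.Tensor,    # [num_seqs, max_blocks_per_seq]
    cu_seqlens_kv: torch.Tensor,  # [num_seqs + 1], prefix sum of gathered lengths
    seq_starts: torch.Tensor,     # [num_seqs], first cached position to gather
) -> tuple[torch.Tensor, torch.Tensor]:
    num_blocks, page_size, num_heads, head_dim = key_cache.shape
    total_tokens = int(cu_seqlens_kv[-1])
    key_out = key_cache.new_empty(total_tokens, num_heads, head_dim)
    value_out = value_cache.new_empty(total_tokens, num_heads, head_dim)
    for b in range(block_table.shape[0]):
        token_start = int(cu_seqlens_kv[b])
        token_end = int(cu_seqlens_kv[b + 1])
        for token_id in range(token_start, token_end):
            # logical position of this token inside the sequence's paged cache
            batch_offset = token_id - token_start + int(seq_starts[b])
            block_id = int(block_table[b, batch_offset // page_size])
            slot_id = batch_offset % page_size
            key_out[token_id] = key_cache[block_id, slot_id]
            value_out[token_id] = value_cache[block_id, slot_id]
    return key_out, value_out
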
@@ -189,6 +183,17 @@ class AiterFlashAttentionPrefillMetadata:
     query_start_loc: torch.Tensor
 
 
+@dataclass
+class AiterChunkSlidingWindowMetadata:
+    swa_seqlens: torch.Tensor
+    swa_cu_seqlens: torch.Tensor
+    swa_seq_starts: torch.Tensor
+    swa_token_to_batch: torch.Tensor
+    swa_max_seqlens: int
+    swa_total_tokens: int
+    swa_workspace: torch.Tensor
+
+
 @dataclass
 class AiterChunkContextMetadata:
     workspace: torch.Tensor
@@ -200,6 +205,7 @@ class AiterChunkContextMetadata:
     seq_lens: torch.Tensor
     num_chunks: int
     total_token_per_batch: list[int]
+    swa_metadata: AiterChunkSlidingWindowMetadata | None
 
 
 @dataclass
@@ -278,6 +284,20 @@ class AiterFlashAttentionMetadataBuilder(
         self.aot_sliding_window: tuple[int, int] | None = None
         self.total_tokens: int = 0
 
+        sliding_window_configs: set[tuple[int, int] | None] = set()
+        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
+        for layer in layers.values():
+            assert isinstance(layer.impl, AiterFlashAttentionImpl)
+            sliding_window_configs.add(layer.impl.sliding_window)
+
+        while len(sliding_window_configs) > 0:
+            sliding_window_config = sliding_window_configs.pop()
+            if sliding_window_config is not None and sliding_window_config[0] != -1:
+                assert self.aot_sliding_window is None, (
+                    "Aiter Flash ATTENTION can only support one valid sliding window!"
+                )
+                self.aot_sliding_window = sliding_window_config
+
         self.extend_workspace = torch.empty(
             [2, _CP_TOKENS_PER_ITER_ROCM, self.num_heads_kv, self.headdim],
             dtype=self.model_config.dtype,
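
The constructor addition above collects each attention layer's sliding-window tuple into a set and keeps at most one entry other than the no-window sentinel. A toy illustration of the same de-duplication, with hypothetical window values:

configs: set[tuple[int, int] | None] = {(-1, -1), (1023, 0), None}
valid = [c for c in configs if c is not None and c[0] != -1]
assert len(valid) <= 1, "only one valid sliding window is supported"
aot_sliding_window = valid[0] if valid else None
print(aot_sliding_window)  # (1023, 0)
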
@@ -349,6 +369,55 @@ class AiterFlashAttentionMetadataBuilder(
             query_lens_for_extend = query_lens_cpu[num_extends_slice]
             seq_lens_for_extend = common_attn_metadata.seq_lens_cpu[num_extends_slice]
             computed_kv_lens = seq_lens_for_extend - query_lens_for_extend
+            swa_metadata = None
+            if self.aot_sliding_window is not None:
+                swa_seqlen_for_extend = torch.minimum(
+                    seq_lens_for_extend,
+                    query_lens_for_extend + self.aot_sliding_window[0] + 1,
+                )
+                cu_seq_lens = torch.zeros(
+                    num_extends + 1,
+                    dtype=torch.int32,
+                    device=seq_lens_for_extend.device,
+                )
+                torch.cumsum(
+                    swa_seqlen_for_extend,
+                    dim=0,
+                    dtype=cu_seq_lens.dtype,
+                    out=cu_seq_lens[1:],
+                )
+                token_to_seq = torch.arange(
+                    0,
+                    num_extends,
+                    dtype=torch.int32,
+                    device=seq_lens_for_extend.device,
+                )
+                token_to_seq = torch.repeat_interleave(
+                    token_to_seq, swa_seqlen_for_extend
+                )
+                fetched_shape = cu_seq_lens[-1].item()
+                # TODO(ganyi): Maybe reuse these 2 buffer from extend_workspace
+                swa_workspace = torch.empty(
+                    (2, fetched_shape, self.num_heads_kv, self.headdim),
+                    dtype=self.vllm_config.model_config.dtype,
+                    device=self.device,
+                )
+
+                seq_starts = seq_lens_for_extend - swa_seqlen_for_extend
+                max_seqlen_k = swa_seqlen_for_extend.max().item()
+                total_tokens = cu_seq_lens[-1].item()
+
+                swa_metadata = AiterChunkSlidingWindowMetadata(
+                    swa_seqlens=swa_seqlen_for_extend.to(
+                        self.device, non_blocking=True
+                    ),
+                    swa_cu_seqlens=cu_seq_lens.to(self.device, non_blocking=True),
+                    swa_seq_starts=seq_starts.to(self.device, non_blocking=True),
+                    swa_token_to_batch=token_to_seq.to(self.device, non_blocking=True),
+                    swa_max_seqlens=max_seqlen_k,
+                    swa_total_tokens=total_tokens,
+                    swa_workspace=swa_workspace,
+                )
 
             # allocate the equal amount of workspace for
             # each chunk prefill request
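
Since aot_sliding_window[0] is the left window size ((sliding_window - 1, 0) after the change further below), query_lens_for_extend + self.aot_sliding_window[0] + 1 equals query_len + sliding_window, so the clamp keeps at most the newest query_len + sliding_window cached positions per request, which is enough for every query token's window. A small numeric sketch of the clamping, with made-up lengths and an assumed left window of 4:

import torch

seq_lens = torch.tensor([10, 37, 6])    # total KV length per extend request
query_lens = torch.tensor([3, 2, 6])    # new tokens in this step
window = 4                              # assumed value of self.aot_sliding_window[0]

swa_seqlens = torch.minimum(seq_lens, query_lens + window + 1)
# tensor([8, 7, 6]) -- at most query_len + window + 1 positions are gathered
seq_starts = seq_lens - swa_seqlens
# tensor([ 2, 30,  0]) -- first cached position each gather starts from
cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
torch.cumsum(swa_seqlens, dim=0, dtype=torch.int32, out=cu_seqlens[1:])
# tensor([ 0,  8, 15, 21], dtype=torch.int32)
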
@@ -392,6 +461,7 @@ class AiterFlashAttentionMetadataBuilder(
                 token_to_batch=token_to_batch_tensor.to(self.device, non_blocking=True),
                 num_chunks=num_chunks,
                 total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist(),
+                swa_metadata=swa_metadata,
             )
 
         query_start_loc_device = common_attn_metadata.query_start_loc[
@@ -504,9 +574,9 @@ class AiterFlashAttentionImpl(AttentionImpl):
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
         if sliding_window is None:
-            self.sliding_window = [-1, -1]
+            self.sliding_window = (-1, -1)
         else:
-            self.sliding_window = [sliding_window - 1, 0]
+            self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
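
The list-to-tuple change matters because the value is now added to the builder's sliding_window_configs set (lists are not hashable) and typed as tuple[int, int]. The values keep the flash-attention (left, right) convention: (sliding_window - 1, 0) lets position i attend causally to at most sliding_window positions ending at i. A small sketch of that convention, with a hypothetical window of 4:

import torch

sliding_window = 4
left, right = sliding_window - 1, 0     # matches self.sliding_window above
i = torch.arange(5).unsqueeze(1)        # query positions
j = torch.arange(5).unsqueeze(0)        # key positions
mask = (j >= i - left) & (j <= i + right)
# row i is True for at most `sliding_window` positions ending at i:
# tensor([[ True, False, False, False, False],
#         [ True,  True, False, False, False],
#         [ True,  True,  True, False, False],
#         [ True,  True,  True,  True, False],
#         [False,  True,  True,  True,  True]])
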
@@ -522,6 +592,67 @@ class AiterFlashAttentionImpl(AttentionImpl):
                 "Encoder self-attention is not implemented for FlashAttentionImpl"
             )
 
+    def extend_for_sliding_window(
+        self,
+        attn_metadata: AiterFlashAttentionMetadata,
+        query: torch.Tensor,
+        key_cache,
+        value_cache,
+        output: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        max_seqlen_q: int,
+        block_table: torch.Tensor,
+        k_scale: float,
+        v_scale: float,
+    ):
+        assert attn_metadata.extend_metadata is not None
+        assert attn_metadata.extend_metadata.chunk_context_metadata is not None
+        chunked_metadata = attn_metadata.extend_metadata.chunk_context_metadata
+        swa_metadata = chunked_metadata.swa_metadata
+        assert swa_metadata is not None
+        swa_cu_seqlens = swa_metadata.swa_cu_seqlens
+        swa_seq_starts = swa_metadata.swa_seq_starts
+        swa_token_to_batch = swa_metadata.swa_token_to_batch
+        swa_max_seqlens = swa_metadata.swa_max_seqlens
+        swa_total_tokens = swa_metadata.swa_total_tokens
+        key_fetched, value_fetched = (
+            swa_metadata.swa_workspace[0],
+            swa_metadata.swa_workspace[1],
+        )
+        cp_mha_gather_cache(
+            key_cache=key_cache,
+            value_cache=value_cache,
+            key=key_fetched,
+            value=value_fetched,
+            block_tables=block_table,
+            k_scales=k_scale,
+            v_scales=v_scale,
+            cu_seqlens_kv=swa_cu_seqlens,
+            token_to_batch=swa_token_to_batch,
+            seq_starts=swa_seq_starts,
+            dequant=False,
+            kv_cache_layout="NHD",
+            total_tokens=swa_total_tokens,
+        )
+
+        aiter.flash_attn_varlen_func(
+            q=query,
+            k=key_fetched,
+            v=value_fetched,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=swa_cu_seqlens,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=swa_max_seqlens,
+            min_seqlen_q=1,
+            dropout_p=0.0,
+            softmax_scale=self.scale,
+            causal=True,
+            window_size=self.sliding_window,
+            alibi_slopes=self.alibi_slopes,
+            return_lse=False,
+            out=output,
+        )
+
     def extend_forward(
         self,
         attn_metadata: AiterFlashAttentionMetadata,
@@ -540,6 +671,20 @@ class AiterFlashAttentionImpl(AttentionImpl):
        k_scale: float,
        v_scale: float,
    ):
+        if self.sliding_window[0] != -1:
+            self.extend_for_sliding_window(
+                attn_metadata,
+                query,
+                key_cache,
+                value_cache,
+                output,
+                cu_seqlens_q,
+                max_seqlen_q,
+                block_table,
+                k_scale,
+                v_scale,
+            )
+            return
        out, lse = aiter.flash_attn_varlen_func(
            q=query,
            k=key,
@@ -782,6 +927,36 @@ class AiterFlashAttentionImpl(AttentionImpl):
         # calculate for decodes
         if num_decodes > 0:
             assert attn_metadata.decode_metadata is not None
+            if self.sliding_window[0] != -1:
+                from aiter.ops.triton.unified_attention import (
+                    unified_attention,
+                )
+
+                descale_shape = (
+                    attn_metadata.query_start_loc[:num_decodes].shape[0] - 1,
+                    key_cache.shape[2],
+                )
+                unified_attention(
+                    q=query[:num_decode_tokens],
+                    k=key_cache,
+                    v=value_cache,
+                    out=output[:num_decode_tokens],
+                    cu_seqlens_q=attn_metadata.query_start_loc[:num_decodes],
+                    max_seqlen_q=1,  # optimize this
+                    seqused_k=attn_metadata.seq_lens[:num_decodes],
+                    max_seqlen_k=attn_metadata.max_seq_len,
+                    softmax_scale=self.scale,
+                    causal=True,
+                    alibi_slopes=self.alibi_slopes,
+                    window_size=self.sliding_window,
+                    block_table=attn_metadata.block_table[:num_decodes],
+                    softcap=self.logits_soft_cap,
+                    q_descale=None,
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
+                )
+                return
+
+            assert attn_metadata.decode_metadata is not None
             _, num_heads, head_size = query.shape
             nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
             num_seqs = attn_metadata.seq_lens.shape[0]