[kernel] fix sliding window in prefix prefill Triton kernel (#4405)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
Michał Moskal
2024-05-02 11:23:37 -07:00
committed by GitHub
parent 5b8a7c1cb0
commit 32881f3f31
6 changed files with 91 additions and 23 deletions

View File

@@ -15,6 +15,7 @@ DTYPES = [torch.float16]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -22,11 +23,13 @@ CUDA_DEVICES = [
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@torch.inference_mode()
def test_contexted_kv_attention(
num_heads: int,
num_queries_per_kv: int,
head_size: int,
sliding_window: int,
dtype: torch.dtype,
device: str,
) -> None:
@@ -123,12 +126,32 @@ def test_contexted_kv_attention(
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
b_start_loc, b_seq_len, b_ctx_len, max_input_len)
context_attention_fwd(query,
k,
v,
output,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
b_ctx_len,
max_input_len,
sliding_window=sliding_window)
torch.cuda.synchronize()
start_time = time.time()
context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
b_start_loc, b_seq_len, b_ctx_len, max_input_len)
context_attention_fwd(query,
k,
v,
output,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
b_ctx_len,
max_input_len,
sliding_window=sliding_window)
torch.cuda.synchronize()
end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
@@ -156,6 +179,9 @@ def test_contexted_kv_attention(
attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
subquery_lens, seq_lens)
if sliding_window > 0:
attn_bias = attn_bias.make_local_attention_from_bottomright(
sliding_window)
output_ref = xops.memory_efficient_attention_forward(
query,
key,