[FP8][Kernel] Dynamic kv cache scaling factors computation (#11906)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
Gregory Shtrasberg
2025-01-23 13:04:03 -05:00
committed by GitHub
parent 6e650f56a1
commit e97f802b2d
60 changed files with 276 additions and 1365 deletions

View File

@@ -138,6 +138,7 @@ def test_contexted_kv_attention(
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = v_cache.view(-1, block_size, num_kv_heads,
head_size).permute(0, 2, 3, 1).contiguous()
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
@@ -153,6 +154,8 @@ def test_contexted_kv_attention(
b_seq_len,
b_ctx_len,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window)
torch.cuda.synchronize()
start_time = time.time()
@@ -168,6 +171,8 @@ def test_contexted_kv_attention(
b_seq_len,
b_ctx_len,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window)
torch.cuda.synchronize()
end_time = time.time()
@@ -366,6 +371,7 @@ def test_contexted_kv_attention_alibi(
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = v_cache.view(-1, block_size, num_kv_heads,
head_size).permute(0, 2, 3, 1).contiguous()
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
@@ -381,6 +387,8 @@ def test_contexted_kv_attention_alibi(
b_seq_len,
b_ctx_len,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes)
torch.cuda.synchronize()
start_time = time.time()
@@ -396,6 +404,8 @@ def test_contexted_kv_attention_alibi(
b_seq_len,
b_ctx_len,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes)
torch.cuda.synchronize()
end_time = time.time()