[FP8][Kernel] Dynamic kv cache scaling factors computation (#11906)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Co-authored-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
committed by
GitHub
parent
6e650f56a1
commit
e97f802b2d
@@ -138,6 +138,7 @@ def test_contexted_kv_attention(
|
||||
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
|
||||
v_cache = v_cache.view(-1, block_size, num_kv_heads,
|
||||
head_size).permute(0, 2, 3, 1).contiguous()
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
# Warm up the Triton kernel by calling it once before actually measuring
|
||||
# generation time
|
||||
@@ -153,6 +154,8 @@ def test_contexted_kv_attention(
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
sliding_window=sliding_window)
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
@@ -168,6 +171,8 @@ def test_contexted_kv_attention(
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
sliding_window=sliding_window)
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
@@ -366,6 +371,7 @@ def test_contexted_kv_attention_alibi(
|
||||
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
|
||||
v_cache = v_cache.view(-1, block_size, num_kv_heads,
|
||||
head_size).permute(0, 2, 3, 1).contiguous()
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
# Warm up the Triton kernel by calling it once before actually measuring
|
||||
# generation time
|
||||
@@ -381,6 +387,8 @@ def test_contexted_kv_attention_alibi(
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes=alibi_slopes)
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
@@ -396,6 +404,8 @@ def test_contexted_kv_attention_alibi(
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes=alibi_slopes)
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
|
||||
Reference in New Issue
Block a user