[FP8][Kernel] Dynamic kv cache scaling factors computation (#11906)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Co-authored-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
committed by
GitHub
parent
6e650f56a1
commit
e97f802b2d
@@ -182,7 +182,7 @@ def test_paged_attention(
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Using default kv_scale
|
||||
k_scale = v_scale = 1.0
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
# Call the paged attention kernel.
|
||||
output = torch.empty_like(query)
|
||||
|
||||
@@ -210,7 +210,7 @@ def test_paged_attention(
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Using default kv_scale
|
||||
k_scale = v_scale = 1.0
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
tp_rank = 0
|
||||
|
||||
# Call the paged attention kernel.
|
||||
|
||||
@@ -160,7 +160,7 @@ def test_reshape_and_cache(
|
||||
cloned_value_cache = value_cache.clone()
|
||||
|
||||
# Using default kv_scale
|
||||
k_scale = v_scale = 1.0
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
# Call the reshape_and_cache kernel.
|
||||
opcheck(torch.ops._C_cache_ops.reshape_and_cache,
|
||||
@@ -258,8 +258,8 @@ def test_reshape_and_cache_flash(
|
||||
del key_caches
|
||||
del value_caches
|
||||
|
||||
k_scale = key.amax().item() / 256
|
||||
v_scale = value.amax().item() / 256
|
||||
k_scale = (key.amax() / 256.0).to(torch.float32)
|
||||
v_scale = (value.amax() / 256.0).to(torch.float32)
|
||||
|
||||
# Clone the KV caches.
|
||||
if kv_cache_dtype == "fp8":
|
||||
@@ -284,12 +284,12 @@ def test_reshape_and_cache_flash(
|
||||
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(result_key_cache,
|
||||
key_cache,
|
||||
k_scale,
|
||||
k_scale.item(),
|
||||
kv_dtype=kv_cache_dtype)
|
||||
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(result_value_cache,
|
||||
value_cache,
|
||||
v_scale,
|
||||
v_scale.item(),
|
||||
kv_dtype=kv_cache_dtype)
|
||||
|
||||
# Run the reference implementation.
|
||||
|
||||
@@ -138,6 +138,7 @@ def test_contexted_kv_attention(
|
||||
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
|
||||
v_cache = v_cache.view(-1, block_size, num_kv_heads,
|
||||
head_size).permute(0, 2, 3, 1).contiguous()
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
# Warm up the Triton kernel by calling it once before actually measuring
|
||||
# generation time
|
||||
@@ -153,6 +154,8 @@ def test_contexted_kv_attention(
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
sliding_window=sliding_window)
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
@@ -168,6 +171,8 @@ def test_contexted_kv_attention(
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
sliding_window=sliding_window)
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
@@ -366,6 +371,7 @@ def test_contexted_kv_attention_alibi(
|
||||
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
|
||||
v_cache = v_cache.view(-1, block_size, num_kv_heads,
|
||||
head_size).permute(0, 2, 3, 1).contiguous()
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
# Warm up the Triton kernel by calling it once before actually measuring
|
||||
# generation time
|
||||
@@ -381,6 +387,8 @@ def test_contexted_kv_attention_alibi(
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes=alibi_slopes)
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
@@ -396,6 +404,8 @@ def test_contexted_kv_attention_alibi(
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes=alibi_slopes)
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
|
||||
@@ -909,6 +909,7 @@ def make_test_metadata(
|
||||
num_prefills=num_prefills,
|
||||
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
@@ -958,6 +959,7 @@ def make_test_metadata(
|
||||
num_prefills=num_prefills,
|
||||
slot_mapping=kv_mmap.slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
|
||||
Reference in New Issue
Block a user