[FP8][Kernel] Dynamic kv cache scaling factors computation (#11906)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
Gregory Shtrasberg
2025-01-23 13:04:03 -05:00
committed by GitHub
parent 6e650f56a1
commit e97f802b2d
60 changed files with 276 additions and 1365 deletions

View File

@@ -160,7 +160,7 @@ def test_reshape_and_cache(
cloned_value_cache = value_cache.clone()
# Using default kv_scale
k_scale = v_scale = 1.0
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache,
@@ -258,8 +258,8 @@ def test_reshape_and_cache_flash(
del key_caches
del value_caches
k_scale = key.amax().item() / 256
v_scale = value.amax().item() / 256
k_scale = (key.amax() / 256.0).to(torch.float32)
v_scale = (value.amax() / 256.0).to(torch.float32)
# Clone the KV caches.
if kv_cache_dtype == "fp8":
@@ -284,12 +284,12 @@ def test_reshape_and_cache_flash(
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(result_key_cache,
key_cache,
k_scale,
k_scale.item(),
kv_dtype=kv_cache_dtype)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(result_value_cache,
value_cache,
v_scale,
v_scale.item(),
kv_dtype=kv_cache_dtype)
# Run the reference implementation.