[Bugfix/Core] Flashinfer k_scale and v_scale (#9861)

Pavani Majety
2024-11-01 12:15:05 -07:00
committed by GitHub
parent aff1fd8188
commit 598b6d7b07
3 changed files with 25 additions and 12 deletions
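For context: the fix threads real fp8 scaling factors through the test instead of the old default k_scale = v_scale = 1.0, so convert_fp8 is now exercised with nontrivial scales. As a minimal sketch of the scale convention (not part of the commit; assumes a PyTorch build with float8_e4m3fn support, and fp8_round_trip is a hypothetical helper):

import torch

def fp8_round_trip(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Illustrative only: divide by the scale so values fit the fp8
    # dynamic range, cast to fp8, then multiply the scale back in.
    q = (x / scale).to(torch.float8_e4m3fn)
    return q.to(torch.float16) * scale

key = torch.randn(16, 8, 64).to(torch.float16)
k_scale = key.amax().item() / 256  # same convention as the test diff below
recovered = fp8_round_trip(key, k_scale)
print((key - recovered).abs().max())  # small fp8 rounding error

Dividing by amax() / 256 keeps the scaled values well inside the E4M3 representable range, which is why the test derives the scales up front rather than using 1.0.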


@@ -258,19 +258,20 @@ def test_reshape_and_cache_flash(
     del key_caches
     del value_caches
+    k_scale = key.amax().item() / 256
+    v_scale = value.amax().item() / 256
 
     # Clone the KV caches.
     if kv_cache_dtype == "fp8":
         cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_key_cache, key_cache)
+        ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype)
         cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_value_cache, value_cache)
+        ops.convert_fp8(cloned_value_cache, value_cache, v_scale,
+                        kv_cache_dtype)
     else:
         cloned_key_cache = key_cache.clone()
         cloned_value_cache = value_cache.clone()
 
-    # Using default kv_scale
-    k_scale = v_scale = 1.0
 
     # Call the reshape_and_cache kernel.
     opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
             (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
@@ -281,9 +282,15 @@ def test_reshape_and_cache_flash(
 
     if kv_cache_dtype == "fp8":
         result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(result_key_cache, key_cache)
+        ops.convert_fp8(result_key_cache,
+                        key_cache,
+                        k_scale,
+                        kv_dtype=kv_cache_dtype)
         result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        ops.convert_fp8(result_value_cache, value_cache)
+        ops.convert_fp8(result_value_cache,
+                        value_cache,
+                        v_scale,
+                        kv_dtype=kv_cache_dtype)
 
     # Run the reference implementation.
     block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
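Downstream of this hunk (not shown in the diff), the dequantized result caches are compared against the cloned reference caches. A hedged sketch of that pattern, with illustrative tolerances that may differ from the file's actual values:

if kv_cache_dtype == "fp8":
    # Loose tolerances absorb fp8 quantization rounding.
    torch.testing.assert_close(result_key_cache,
                               cloned_key_cache,
                               atol=0.001,
                               rtol=0.1)
else:
    torch.testing.assert_close(result_key_cache, cloned_key_cache)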