[Kernel] Refactor FP8 kv-cache with NVIDIA float8_e4m3 support (#4535)

Author: Cody Yu
Date: 2024-05-09 17:04:17 -07:00
Committed by: GitHub
parent 379da6dcb5
commit c833101740
17 changed files with 843 additions and 558 deletions


@@ -236,14 +236,14 @@ def test_paged_attention(
         dequantized_key_cache = torch.empty(size=key_cache_shape,
                                             dtype=dtype,
                                             device=device)
-        ops.convert_fp8(key_cache, dequantized_key_cache)
+        ops.convert_fp8(dequantized_key_cache, key_cache)
         key_cache = dequantized_key_cache
 
         value_cache_shape = value_cache.shape
         dequantized_value_cache = torch.empty(size=value_cache_shape,
                                               dtype=dtype,
                                               device=device)
-        ops.convert_fp8(value_cache, dequantized_value_cache)
+        ops.convert_fp8(dequantized_value_cache, value_cache)
         value_cache = dequantized_value_cache
 
         ref_output = torch.empty_like(query)
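
For context on what this hunk changes: after the refactor, convert_fp8 takes the destination tensor first, so dequantization writes the fp8 kv-cache back into a higher-precision buffer in place. Below is a minimal sketch of that round-trip using only stock PyTorch (>= 2.1, for torch.float8_e4m3fn, the NVIDIA e4m3 format named in the commit title); the quantize_to_fp8 and dequantize_from_fp8 helpers are hypothetical stand-ins for illustration, not vLLM's actual kernels.

    import torch

    def quantize_to_fp8(src: torch.Tensor) -> torch.Tensor:
        # Cast down to NVIDIA's e4m3 format (4 exponent bits, 3 mantissa bits).
        return src.to(torch.float8_e4m3fn)

    def dequantize_from_fp8(dst: torch.Tensor, src: torch.Tensor) -> None:
        # Destination-first, in-place convention, mirroring the new
        # ops.convert_fp8(dequantized_cache, cache) call order in the diff.
        dst.copy_(src.to(dst.dtype))

    key_cache = torch.randn(2, 4, 16, dtype=torch.float16)
    fp8_key_cache = quantize_to_fp8(key_cache)

    dequantized_key_cache = torch.empty_like(key_cache)
    dequantize_from_fp8(dequantized_key_cache, fp8_key_cache)

    # e4m3 keeps only 3 mantissa bits, so expect roughly 6% relative
    # error at most for normal values; use a loose tolerance.
    assert torch.allclose(key_cache, dequantized_key_cache, rtol=0.1, atol=0.01)

The test in the diff follows the same shape: allocate an empty buffer in the model dtype, dequantize the fp8 cache into it, then reuse that buffer as the reference cache for the attention check.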