[ROCm][Quantization][Kernel] Using HIP FP8 header (#12593)

This commit is contained in:
Gregory Shtrasberg
2025-02-25 03:39:59 -05:00
committed by GitHub
parent 2f42a4888c
commit aabeb2688f
6 changed files with 267 additions and 634 deletions

View File

@@ -159,19 +159,20 @@ def test_reshape_and_cache(
device)
key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale
k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32)
# Clone the KV caches.
if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(cloned_key_cache, key_cache)
ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item())
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(cloned_value_cache, value_cache)
ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item())
else:
cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone()
# Using default kv_scale
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache,
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
@@ -182,9 +183,9 @@ def test_reshape_and_cache(
if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(result_key_cache, key_cache)
ops.convert_fp8(result_key_cache, key_cache, k_scale.item())
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(result_value_cache, value_cache)
ops.convert_fp8(result_value_cache, value_cache, v_scale.item())
# Run the reference implementation.
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
@@ -268,15 +269,16 @@ def test_reshape_and_cache_flash(
del key_caches
del value_caches
k_scale = (key.amax() / 256.0).to(torch.float32)
v_scale = (value.amax() / 256.0).to(torch.float32)
k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32)
# Clone the KV caches.
if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype)
ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item(),
kv_cache_dtype)
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(cloned_value_cache, value_cache, v_scale,
ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item(),
kv_cache_dtype)
else:
cloned_key_cache = key_cache.clone()