Add fp8 support to reshape_and_cache_flash (#6667)
This commit is contained in:
@@ -215,8 +215,6 @@ def test_reshape_and_cache_flash(
|
||||
device: str,
|
||||
kv_cache_dtype: str,
|
||||
) -> None:
|
||||
if kv_cache_dtype == "fp8":
|
||||
pytest.skip()
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
@@ -248,15 +246,33 @@ def test_reshape_and_cache_flash(
|
||||
dtype,
|
||||
device=device,
|
||||
)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
key_cache, value_cache = key_caches[0].contiguous(
|
||||
), value_caches[0].contiguous()
|
||||
del key_caches
|
||||
del value_caches
|
||||
|
||||
# Clone the KV caches.
|
||||
cloned_key_cache = key_cache.clone()
|
||||
cloned_value_cache = value_cache.clone()
|
||||
if kv_cache_dtype == "fp8":
|
||||
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(cloned_key_cache, key_cache)
|
||||
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(cloned_value_cache, value_cache)
|
||||
else:
|
||||
cloned_key_cache = key_cache.clone()
|
||||
cloned_value_cache = value_cache.clone()
|
||||
|
||||
# Using default kv_scale
|
||||
k_scale = v_scale = 1.0
|
||||
|
||||
# Call the reshape_and_cache kernel.
|
||||
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
|
||||
slot_mapping, kv_cache_dtype)
|
||||
slot_mapping, kv_cache_dtype, k_scale, v_scale)
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(result_key_cache, key_cache)
|
||||
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(result_value_cache, value_cache)
|
||||
|
||||
# Run the reference implementation.
|
||||
block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
|
||||
@@ -269,8 +285,18 @@ def test_reshape_and_cache_flash(
|
||||
cloned_key_cache[block_idx, block_offset, :, :] = key[i]
|
||||
cloned_value_cache[block_idx, block_offset, :, :] = value[i]
|
||||
|
||||
assert torch.allclose(key_cache, cloned_key_cache)
|
||||
assert torch.allclose(value_cache, cloned_value_cache)
|
||||
if kv_cache_dtype == "fp8":
|
||||
assert torch.allclose(result_key_cache,
|
||||
cloned_key_cache,
|
||||
atol=0.001,
|
||||
rtol=0.1)
|
||||
assert torch.allclose(result_value_cache,
|
||||
cloned_value_cache,
|
||||
atol=0.001,
|
||||
rtol=0.1)
|
||||
else:
|
||||
assert torch.allclose(key_cache, cloned_key_cache)
|
||||
assert torch.allclose(value_cache, cloned_value_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
|
||||
|
||||
Reference in New Issue
Block a user