[Core/Bugfix] Add FP8 K/V Scale and dtype conversion for prefix/prefill Triton Kernel (#7208)

Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>

Author: jon-chuang
Date: 2024-08-12 15:47:41 -07:00 (committed by GitHub)
Commit: a046f86397 (parent: 4ddc4743d7)
10 changed files with 208 additions and 47 deletions

@@ -9,6 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
 from vllm.attention.backends.xformers import _make_alibi_bias
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 
 NUM_HEADS = [64]
 NUM_QUERIES_PER_KV = [1, 8, 64]
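
The newly imported STR_DTYPE_TO_TORCH_DTYPE is what lets the test allocate caches in the storage dtype selected by a kv_cache_dtype string. As a reading aid only (a sketch, not vLLM's code), the mapping behaves roughly like the following, assuming the fp8 variants are stored in a raw 8-bit container (torch.uint8) and reinterpreted inside the kernel, which is how vLLM packed fp8 caches at the time:

    import torch

    # Hypothetical stand-in for vllm.utils.STR_DTYPE_TO_TORCH_DTYPE.
    # Assumption: fp8 formats map to uint8 storage; only the non-fp8
    # entries map to "real" compute dtypes.
    STR_TO_DTYPE = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
        "fp8": torch.uint8,
        "fp8_e5m2": torch.uint8,
    }

    def resolve_cache_dtype(kv_cache_dtype: str,
                            model_dtype: torch.dtype) -> torch.dtype:
        # "auto" means: keep the cache in the model's own dtype.
        if kv_cache_dtype == "auto":
            return model_dtype
        return STR_TO_DTYPE[kv_cache_dtype]

    assert resolve_cache_dtype("auto", torch.float16) is torch.float16
    assert resolve_cache_dtype("fp8_e5m2", torch.float16) is torch.uint8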

@@ -18,12 +19,14 @@ CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
 SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
+KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
 
 
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
 @torch.inference_mode()

@@ -33,6 +36,7 @@ def test_contexted_kv_attention(
     head_size: int,
     sliding_window: int,
     dtype: torch.dtype,
+    kv_cache_dtype: str,
     device: str,
 ) -> None:
     random.seed(0)

@@ -67,16 +71,20 @@ def test_contexted_kv_attention(
     kv.uniform_(-1e-3, 1e-3)
     key, value = kv.unbind(dim=1)
 
+    if kv_cache_dtype == "auto":
+        cache_dtype = dtype
+    else:
+        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
     k_cache = torch.zeros(cache_size,
                           block_size,
                           num_kv_heads,
                           head_size,
-                          dtype=dtype)
+                          dtype=cache_dtype)
     v_cache = torch.zeros(cache_size,
                           block_size,
                           num_kv_heads,
                           head_size,
-                          dtype=dtype)
+                          dtype=cache_dtype)
     k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
     v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
     values = torch.arange(0, cache_size, dtype=torch.long)
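
Note the asymmetry this hunk introduces: k_cache and v_cache are allocated in the resolved storage dtype, while k, v, and the query stay in the compute dtype. The "FP8 K/V Scale" half of the commit title refers to scaling values into fp8's representable range before narrowing. A minimal sketch of that idea, not the kernel's actual code path (torch.float8_e5m2 requires PyTorch >= 2.1):

    import torch

    def quantize_kv_fp8(t: torch.Tensor, scale: float) -> torch.Tensor:
        # Divide by the scale so values fit fp8's narrow range, then narrow.
        return (t / scale).to(torch.float8_e5m2)

    def dequantize_kv_fp8(t: torch.Tensor, scale: float,
                          dtype: torch.dtype = torch.float16) -> torch.Tensor:
        # Widen back to the compute dtype and undo the scale.
        return t.to(dtype) * scale

    k = torch.empty(16, 8, 128).uniform_(-1e-3, 1e-3)
    k_back = dequantize_kv_fp8(quantize_kv_fp8(k, scale=1.0), scale=1.0)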

@@ -132,6 +140,7 @@ def test_contexted_kv_attention(
                           k,
                           v,
                           output,
+                          kv_cache_dtype,
                           k_cache,
                           v_cache,
                           block_table,

@@ -146,6 +155,7 @@ def test_contexted_kv_attention(
                           k,
                           v,
                           output,
+                          kv_cache_dtype,
                           k_cache,
                           v_cache,
                           block_table,

@@ -208,13 +218,15 @@ def test_contexted_kv_attention(
     end_time = time.time()
     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
     output_ref = output_ref.reshape(output.shape)
-    assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
+    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
+    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
 
 
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
 def test_contexted_kv_attention_alibi(
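
Relaxing atol to 1e-3 for the fp8 cases (and switching to torch.testing.assert_close, which prints a mismatch summary on failure) is justified by fp8's coarse precision: e5m2 keeps only two mantissa bits. A standalone check of the round-trip error on inputs like this test's uniform_(-1e-3, 1e-3) draws, assuming a PyTorch build with torch.float8_e5m2:

    import torch

    x = torch.empty(4096).uniform_(-1e-3, 1e-3)
    roundtrip = x.to(torch.float8_e5m2).to(torch.float32)
    # Max error is on the order of 1e-4: far above the old 1e-6 bound,
    # comfortably inside the new 1e-3 one.
    print((x - roundtrip).abs().max())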

@@ -222,6 +234,7 @@ def test_contexted_kv_attention_alibi(
     num_queries_per_kv: int,
     head_size: int,
     dtype: torch.dtype,
+    kv_cache_dtype: str,
     device: str,
 ) -> None:
     random.seed(0)

@@ -282,17 +295,20 @@ def test_contexted_kv_attention_alibi(
     kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
     kv.uniform_(-1e-3, 1e-3)
     key, value = kv.unbind(dim=1)
 
+    if kv_cache_dtype == "auto":
+        cache_dtype = dtype
+    else:
+        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
     k_cache = torch.zeros(cache_size,
                           block_size,
                           num_kv_heads,
                           head_size,
-                          dtype=dtype)
+                          dtype=cache_dtype)
     v_cache = torch.zeros(cache_size,
                           block_size,
                           num_kv_heads,
                           head_size,
-                          dtype=dtype)
+                          dtype=cache_dtype)
     k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
     v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
     values = torch.arange(0, cache_size, dtype=torch.long)

@@ -348,6 +364,7 @@ def test_contexted_kv_attention_alibi(
                           k,
                           v,
                           output,
+                          kv_cache_dtype,
                           k_cache,
                           v_cache,
                           block_table,

@@ -362,6 +379,7 @@ def test_contexted_kv_attention_alibi(
                           k,
                           v,
                           output,
+                          kv_cache_dtype,
                           k_cache,
                           v_cache,
                           block_table,

@@ -447,4 +465,5 @@ def test_contexted_kv_attention_alibi(
     torch.cuda.synchronize()
     end_time = time.time()
     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
-    assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
+    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
+    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
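
To run only the new fp8 cases locally, something like the following should work, assuming the test file is tests/kernels/test_prefix_prefill.py (the path is not shown on this page):

    pytest tests/kernels/test_prefix_prefill.py -k fp8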