[Core/Bugfix] Add FP8 K/V Scale and dtype conversion for prefix/prefill Triton Kernel (#7208)
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
@@ -9,6 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
|
||||
|
||||
from vllm.attention.backends.xformers import _make_alibi_bias
|
||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
NUM_HEADS = [64]
|
||||
NUM_QUERIES_PER_KV = [1, 8, 64]
|
||||
@@ -18,12 +19,14 @@ CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
|
||||
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
|
||||
@torch.inference_mode()
|
||||
@@ -33,6 +36,7 @@ def test_contexted_kv_attention(
|
||||
head_size: int,
|
||||
sliding_window: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
) -> None:
|
||||
random.seed(0)
|
||||
@@ -67,16 +71,20 @@ def test_contexted_kv_attention(
|
||||
kv.uniform_(-1e-3, 1e-3)
|
||||
key, value = kv.unbind(dim=1)
|
||||
|
||||
if kv_cache_dtype == "auto":
|
||||
cache_dtype = dtype
|
||||
else:
|
||||
cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
|
||||
k_cache = torch.zeros(cache_size,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
dtype=cache_dtype)
|
||||
v_cache = torch.zeros(cache_size,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
dtype=cache_dtype)
|
||||
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
values = torch.arange(0, cache_size, dtype=torch.long)
|
||||
@@ -132,6 +140,7 @@ def test_contexted_kv_attention(
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
@@ -146,6 +155,7 @@ def test_contexted_kv_attention(
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
@@ -208,13 +218,15 @@ def test_contexted_kv_attention(
|
||||
end_time = time.time()
|
||||
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
|
||||
output_ref = output_ref.reshape(output.shape)
|
||||
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
|
||||
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
|
||||
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_contexted_kv_attention_alibi(
|
||||
@@ -222,6 +234,7 @@ def test_contexted_kv_attention_alibi(
|
||||
num_queries_per_kv: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
) -> None:
|
||||
random.seed(0)
|
||||
@@ -282,17 +295,20 @@ def test_contexted_kv_attention_alibi(
|
||||
kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
|
||||
kv.uniform_(-1e-3, 1e-3)
|
||||
key, value = kv.unbind(dim=1)
|
||||
|
||||
if kv_cache_dtype == "auto":
|
||||
cache_dtype = dtype
|
||||
else:
|
||||
cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
|
||||
k_cache = torch.zeros(cache_size,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
dtype=cache_dtype)
|
||||
v_cache = torch.zeros(cache_size,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
dtype=cache_dtype)
|
||||
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
values = torch.arange(0, cache_size, dtype=torch.long)
|
||||
@@ -348,6 +364,7 @@ def test_contexted_kv_attention_alibi(
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
@@ -362,6 +379,7 @@ def test_contexted_kv_attention_alibi(
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
@@ -447,4 +465,5 @@ def test_contexted_kv_attention_alibi(
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
|
||||
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
|
||||
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
|
||||
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
|
||||
|
||||
Reference in New Issue
Block a user