Support FP8-E5M2 KV Cache (#2279)

Co-authored-by: zhaoyang <zhao.yang16@zte.com.cn>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
zhaoyang-star
2024-01-29 08:43:54 +08:00
committed by GitHub
parent 7d648418b8
commit 9090bf02e7
26 changed files with 912 additions and 196 deletions

View File

@@ -1,9 +1,11 @@
from typing import Optional
import argparse
import random
import time
import torch
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
from vllm._C import ops
NUM_BLOCKS = 1024
@@ -23,6 +25,7 @@ def main(
dtype: torch.dtype,
seed: int,
do_profile: bool,
kv_cache_dtype: Optional[str] = None,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
@@ -59,15 +62,10 @@ def main(
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
# Create the KV cache.
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
key_cache.uniform_(-scale, scale)
value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device="cuda")
value_cache.uniform_(-scale, scale)
key_caches, value_caches = create_kv_caches_with_random(
NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
dtype)
key_cache, value_cache = key_caches[0], value_caches[0]
# Prepare for the paged attention kernel.
output = torch.empty_like(query)
@@ -106,6 +104,7 @@ def main(
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
)
elif version == "v2":
ops.paged_attention_v2(
@@ -123,6 +122,7 @@ def main(
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
)
else:
raise ValueError(f"Invalid version: {version}")
@@ -168,16 +168,18 @@ if __name__ == '__main__':
default="half")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true")
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8_e5m2"],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type.')
args = parser.parse_args()
print(args)
if args.num_query_heads % args.num_kv_heads != 0:
raise ValueError("num_query_heads must be divisible by num_kv_heads")
dtype_to_torch_dtype = {
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
}
main(
version=args.version,
num_seqs=args.batch_size,
@@ -187,7 +189,8 @@ if __name__ == '__main__':
head_size=args.head_size,
block_size=args.block_size,
use_alibi=args.use_alibi,
dtype=dtype_to_torch_dtype[args.dtype],
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
seed=args.seed,
do_profile=args.profile,
kv_cache_dtype=args.kv_cache_dtype,
)