Support FP8-E5M2 KV Cache (#2279)

Co-authored-by: zhaoyang <zhao.yang16@zte.com.cn>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
zhaoyang-star
2024-01-29 08:43:54 +08:00
committed by GitHub
parent 7d648418b8
commit 9090bf02e7
26 changed files with 912 additions and 196 deletions

View File

@@ -1,13 +1,14 @@
from typing import Optional, Union, ClassVar
from dataclasses import dataclass
import os
from packaging.version import Version
import torch
from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
from vllm.utils import get_cpu_memory, is_hip
from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version
logger = init_logger(__name__)
@@ -275,6 +276,7 @@ class CacheConfig:
gpu_memory_utilization: Fraction of GPU memory to use for the
vLLM execution.
swap_space: Size of the CPU swap space per GPU (in GiB).
cache_dtype: Data type for kv cache storage.
"""
def __init__(
@@ -282,13 +284,16 @@ class CacheConfig:
block_size: int,
gpu_memory_utilization: float,
swap_space: int,
cache_dtype: str,
sliding_window: Optional[int] = None,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GB
self.cache_dtype = cache_dtype
self.sliding_window = sliding_window
self._verify_args()
self._verify_cache_dtype()
# Will be set after profiling.
self.num_gpu_blocks = None
@@ -300,6 +305,28 @@ class CacheConfig:
"GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.")
def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto":
pass
elif self.cache_dtype == "fp8_e5m2":
nvcc_cuda_version = get_nvcc_cuda_version()
if nvcc_cuda_version < Version("11.8"):
raise ValueError(
"FP8 is not supported when cuda version is lower than 11.8."
)
device_name = torch.cuda.get_device_name()
if "AMD" in device_name:
raise NotImplementedError(
"FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
logger.info(
"Using fp8_e5m2 data type to store kv cache. It reduces "
"the GPU memory footprint and boosts the performance. "
"But it may cause slight accuracy drop. "
"Currently we only support fp8 without scaling factors and "
"make e5m2 as a default format.")
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",