Enable scaled FP8 (e4m3fn) KV cache on ROCm (AMD GPU) (#3290)

Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: HaiShaw <hixiao@gmail.com>
Co-authored-by: AdrianAbeyta <Adrian.Abeyta@amd.com>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: root <root@gt-pla-u18-08.pla.dcgpu>
Co-authored-by: mawong-amd <156021403+mawong-amd@users.noreply.github.com>
Co-authored-by: ttbachyinsda <ttbachyinsda@outlook.com>
Co-authored-by: guofangze <guofangze@kuaishou.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: jacobthebanana <50071502+jacobthebanana@users.noreply.github.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Adrian Abeyta
2024-04-03 16:15:55 -05:00
committed by GitHub
parent 3dcb3e8b98
commit 2ff767b513
41 changed files with 2592 additions and 142 deletions

View File

@@ -97,6 +97,9 @@ def main(
torch.cuda.cudart().cudaProfilerStart()
start_time = time.perf_counter()
# Using default kv_scale
kv_scale = 1.0
for _ in range(num_iters):
if version == "v1":
ops.paged_attention_v1(
@@ -112,6 +115,7 @@ def main(
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
elif version == "v2":
ops.paged_attention_v2(
@@ -130,6 +134,7 @@ def main(
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
else:
raise ValueError(f"Invalid version: {version}")
@@ -179,11 +184,13 @@ if __name__ == '__main__':
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8_e5m2"],
choices=["auto", "fp8"],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type.')
parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
args = parser.parse_args()
print(args)