Enable scaled FP8 (e4m3fn) KV cache on ROCm (AMD GPU) (#3290)
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: HaiShaw <hixiao@gmail.com>
Co-authored-by: AdrianAbeyta <Adrian.Abeyta@amd.com>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: root <root@gt-pla-u18-08.pla.dcgpu>
Co-authored-by: mawong-amd <156021403+mawong-amd@users.noreply.github.com>
Co-authored-by: ttbachyinsda <ttbachyinsda@outlook.com>
Co-authored-by: guofangze <guofangze@kuaishou.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: jacobthebanana <50071502+jacobthebanana@users.noreply.github.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -97,6 +97,9 @@ def main(
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Using default kv_scale
|
||||
kv_scale = 1.0
|
||||
|
||||
for _ in range(num_iters):
|
||||
if version == "v1":
|
||||
ops.paged_attention_v1(
|
||||
@@ -112,6 +115,7 @@ def main(
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
kv_scale,
|
||||
)
|
||||
elif version == "v2":
|
||||
ops.paged_attention_v2(
|
||||
@@ -130,6 +134,7 @@ def main(
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
kv_scale,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid version: {version}")
|
||||
@@ -179,11 +184,13 @@ if __name__ == '__main__':
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8_e5m2"],
|
||||
choices=["auto", "fp8"],
|
||||
default="auto",
|
||||
help=
|
||||
'Data type for kv cache storage. If "auto", will use model data type.')
|
||||
parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
|
||||
'Data type for kv cache storage. If "auto", will use model data type. '
|
||||
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
|
||||
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
|
||||
'common inference criteria.')
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user