[Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893)

The 2nd PR for #4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
This commit is contained in:
Cody Yu
2024-05-22 13:28:20 -07:00
committed by GitHub
parent 8674f9880e
commit a3a73ab069
40 changed files with 284 additions and 158 deletions

View File

@@ -183,13 +183,11 @@ if __name__ == '__main__':
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8"],
choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
help="Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
args = parser.parse_args()
print(args)