[Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893)

The 2nd PR for #4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
This commit is contained in:
Cody Yu
2024-05-22 13:28:20 -07:00
committed by GitHub
parent 8674f9880e
commit a3a73ab069
40 changed files with 284 additions and 158 deletions

View File

@@ -88,7 +88,8 @@ class GPTBigCodeAttention(nn.Module):
self.head_dim,
scale=self.scale,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,