[Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893)

The 2nd PR for #4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
This commit is contained in:
Cody Yu
2024-05-22 13:28:20 -07:00
committed by GitHub
parent 8674f9880e
commit a3a73ab069
40 changed files with 284 additions and 158 deletions

View File

@@ -96,7 +96,8 @@ class OlmoAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
# Attention output projection.
self.o_proj = RowParallelLinear(