[Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893)

The 2nd PR for #4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
This commit is contained in:
Cody Yu
2024-05-22 13:28:20 -07:00
committed by GitHub
parent 8674f9880e
commit a3a73ab069
40 changed files with 284 additions and 158 deletions

View File

@@ -86,13 +86,12 @@ class GLMAttention(nn.Module):
base=10000 * rope_ratio,
is_neox_style=False,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,