[Bugfix] Change kv scaling factor by param json on nvidia gpu (#11688)
Signed-off-by: bjmsong <bjmsong@126.com>
Co-authored-by: bjmsong <bjmsong@126.com>
This commit is contained in:
@@ -606,8 +606,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")

@@ -545,8 +545,9 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")

@@ -452,8 +452,9 @@ class LlamaModel(nn.Module):
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")

@@ -565,8 +565,9 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")

Reference in New Issue
Block a user