[Bugfix] Fix MRotaryEmbedding missing truncate attr with YaRN scaling (#35080)
Signed-off-by: haosdent <haosdent@gmail.com>
This commit is contained in:
@@ -218,12 +218,14 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
|
|||||||
attn_factor: float = 1,
|
attn_factor: float = 1,
|
||||||
beta_fast: int = 32,
|
beta_fast: int = 32,
|
||||||
beta_slow: int = 1,
|
beta_slow: int = 1,
|
||||||
|
truncate: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.scaling_factor = scaling_factor
|
self.scaling_factor = scaling_factor
|
||||||
self.extrapolation_factor = extrapolation_factor
|
self.extrapolation_factor = extrapolation_factor
|
||||||
self.attn_factor = attn_factor
|
self.attn_factor = attn_factor
|
||||||
self.beta_fast = beta_fast
|
self.beta_fast = beta_fast
|
||||||
self.beta_slow = beta_slow
|
self.beta_slow = beta_slow
|
||||||
|
self.truncate = truncate
|
||||||
if self.scaling_factor is not None:
|
if self.scaling_factor is not None:
|
||||||
# Get n-d magnitude scaling corrected for interpolation
|
# Get n-d magnitude scaling corrected for interpolation
|
||||||
self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
|
self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
|
||||||
|
|||||||
Reference in New Issue
Block a user