Add support for a rope extension method (#6553)
This commit is contained in:
@@ -151,6 +151,15 @@ class ModelConfig:
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||
|
||||
if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
|
||||
and getattr(self.hf_config, "rope_scaling", None) is None):
|
||||
# Note(simon): this is a special case for a model that doesn't
|
||||
# supply rope_scaling. We should remove this once the model is
|
||||
# updated.
|
||||
self.hf_config.update({"rope_scaling": {
|
||||
"type": "extended",
|
||||
}})
|
||||
|
||||
if (not self.disable_sliding_window
|
||||
and self.hf_text_config.model_type == "gemma2"
|
||||
and self.hf_text_config.sliding_window is not None):
|
||||
@@ -1442,8 +1451,9 @@ def _get_and_verify_max_len(
|
||||
rope_scaling = getattr(hf_config, "rope_scaling", None)
|
||||
# The correct one should be "longrope", kept "su" here
|
||||
# to be backward compatible
|
||||
if rope_scaling is not None and rope_scaling["type"] != "su" \
|
||||
and rope_scaling["type"] != "longrope":
|
||||
if rope_scaling is not None and rope_scaling["type"] not in {
|
||||
"su", "longrope", "extended"
|
||||
}:
|
||||
if disable_sliding_window:
|
||||
# TODO(robertgshaw): Find a model that supports rope_scaling
|
||||
# with sliding window to see if this case should be allowed.
|
||||
|
||||
Reference in New Issue
Block a user