diff --git a/src/llamafactory/model/model_utils/rope.py b/src/llamafactory/model/model_utils/rope.py
index d04279e0..b217735b 100644
--- a/src/llamafactory/model/model_utils/rope.py
+++ b/src/llamafactory/model/model_utils/rope.py
@@ -40,7 +40,10 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments") ->
         logger.warning_rank0("Current model does not support RoPE scaling.")
         return
 
-    if hasattr(config, "max_position_embeddings"):
+    rope_scaling = getattr(config, "rope_scaling", None)
+    if isinstance(rope_scaling, dict) and "original_max_position_embeddings" in rope_scaling:
+        old_max_length = rope_scaling["original_max_position_embeddings"]
+    elif hasattr(config, "max_position_embeddings"):
         old_max_length = getattr(config, "max_position_embeddings", None)
     else:
         logger.warning_rank0("Cannot find the max position embeddings in the config.")
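The effect of the change: for models whose config already carries a rope_scaling dict with an original_max_position_embeddings entry, old_max_length is taken from that recorded pre-scaling length instead of max_position_embeddings. Below is a minimal standalone sketch of the new lookup order; the resolve_old_max_length helper and the SimpleNamespace stand-in for a PretrainedConfig are illustrative only and not part of the patch.

# Sketch of the lookup order introduced by this diff (names here are hypothetical).
from types import SimpleNamespace

def resolve_old_max_length(config):
    rope_scaling = getattr(config, "rope_scaling", None)
    if isinstance(rope_scaling, dict) and "original_max_position_embeddings" in rope_scaling:
        # Model was already RoPE-scaled: use the recorded pre-scaling context length.
        return rope_scaling["original_max_position_embeddings"]
    elif hasattr(config, "max_position_embeddings"):
        return getattr(config, "max_position_embeddings", None)
    return None

# Example: a Phi-3-style config that ships with rope_scaling already applied.
config = SimpleNamespace(
    max_position_embeddings=131072,
    rope_scaling={"rope_type": "longrope", "original_max_position_embeddings": 4096},
)
print(resolve_old_max_length(config))  # 4096, not 131072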