From a6f3adf930755be0f983fdea2e2926f86977f8ce Mon Sep 17 00:00:00 2001
From: piamo
Date: Mon, 19 May 2025 20:07:54 +0800
Subject: [PATCH] [model] update rope kwargs for yarn (#8101)

---
 src/llamafactory/model/model_utils/rope.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llamafactory/model/model_utils/rope.py b/src/llamafactory/model/model_utils/rope.py
index 30d0fdd7..29b56a0b 100644
--- a/src/llamafactory/model/model_utils/rope.py
+++ b/src/llamafactory/model/model_utils/rope.py
@@ -56,7 +56,7 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
         setattr(config, "max_position_embeddings", model_args.model_max_length)
         rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
-        if model_args.rope_scaling == RopeScaling.DYNAMIC:
+        if model_args.rope_scaling in [RopeScaling.DYNAMIC, RopeScaling.YARN]:
             rope_kwargs["original_max_position_embeddings"] = current_max_length
         elif model_args.rope_scaling == RopeScaling.LLAMA3:
             rope_kwargs["original_max_position_embeddings"] = current_max_length
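
Note (not part of the patch): a minimal sketch of how the rope kwargs would be populated for YARN after this change, assuming a HuggingFace-style rope_scaling dict with a "rope_type" key and hypothetical lengths (current_max_length = 4096 extended to model_max_length = 16384); the variable names and numbers below are illustrative, not taken from the repository.

```python
import math

# Hypothetical values standing in for config.max_position_embeddings and
# model_args.model_max_length in configure_rope().
current_max_length = 4096
model_max_length = 16384

rope_kwargs = {"rope_type": "yarn"}
# Same factor computation as in the patched function: ceil(16384 / 4096) -> 4.0
rope_kwargs["factor"] = float(math.ceil(model_max_length / current_max_length))
# Behavior added by this patch: YARN, like DYNAMIC, also records the original length.
rope_kwargs["original_max_position_embeddings"] = current_max_length

print(rope_kwargs)
# {'rope_type': 'yarn', 'factor': 4.0, 'original_max_position_embeddings': 4096}
```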