mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 20:00:36 +08:00
[config] Fix RoPE scaling patch for resuming from a scaled model (#9588)
@@ -40,7 +40,10 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments") ->
         logger.warning_rank0("Current model does not support RoPE scaling.")
         return
 
-    if hasattr(config, "max_position_embeddings"):
+    rope_scaling = getattr(config, "rope_scaling", None)
+    if isinstance(rope_scaling, dict) and "original_max_position_embeddings" in rope_scaling:
+        old_max_length = rope_scaling["original_max_position_embeddings"]
+    elif hasattr(config, "max_position_embeddings"):
         old_max_length = getattr(config, "max_position_embeddings", None)
     else:
         logger.warning_rank0("Cannot find the max position embeddings in the config.")
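For context, a minimal runnable sketch of the selection logic this commit introduces, assuming a config shaped like a transformers PretrainedConfig; the SimpleNamespace stand-ins and the helper name resolve_old_max_length are illustrative, not from the repository:

from types import SimpleNamespace


def resolve_old_max_length(config):
    """Pick the pre-scaling context length, handling already-scaled checkpoints."""
    rope_scaling = getattr(config, "rope_scaling", None)
    if isinstance(rope_scaling, dict) and "original_max_position_embeddings" in rope_scaling:
        # Resuming from a scaled model: max_position_embeddings already
        # reflects the previous scaling, so use the recorded original length.
        return rope_scaling["original_max_position_embeddings"]
    if hasattr(config, "max_position_embeddings"):
        return getattr(config, "max_position_embeddings", None)
    return None  # the caller logs "Cannot find the max position embeddings ..."


# A checkpoint previously scaled 4096 -> 16384 with linear RoPE scaling.
scaled = SimpleNamespace(
    max_position_embeddings=16384,
    rope_scaling={"rope_type": "linear", "factor": 4.0, "original_max_position_embeddings": 4096},
)
fresh = SimpleNamespace(max_position_embeddings=4096, rope_scaling=None)

assert resolve_old_max_length(scaled) == 4096  # the old branch would have picked 16384
assert resolve_old_max_length(fresh) == 4096

The apparent bug this addresses: when resuming from an already-scaled model, the old code derived old_max_length from the scaled max_position_embeddings, so the new scaling factor would be computed against an inflated base; reading original_max_position_embeddings out of the existing rope_scaling dict keeps the base length stable across resumes.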