Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)
[model] update rope kwargs for yarn (#8101)
Commit: a6f3adf930 (parent: ed2f89efaf)
```diff
@@ -56,7 +56,7 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
         setattr(config, "max_position_embeddings", model_args.model_max_length)
         rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
-        if model_args.rope_scaling == RopeScaling.DYNAMIC:
+        if model_args.rope_scaling in [RopeScaling.DYNAMIC, RopeScaling.YARN]:
             rope_kwargs["original_max_position_embeddings"] = current_max_length
         elif model_args.rope_scaling == RopeScaling.LLAMA3:
             rope_kwargs["original_max_position_embeddings"] = current_max_length
```
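The functional effect of the change above is that YARN scaling now records `original_max_position_embeddings` in its rope kwargs, just as DYNAMIC already did. Below is a minimal sketch of what the resulting rope kwargs could look like; the concrete lengths (a 4096-token base context extended to 16384) and the `rope_type` key are illustrative assumptions, not values taken from the commit.

```python
import math

# Hypothetical lengths, purely for illustration (not from the commit):
current_max_length = 4096   # pretrained config.max_position_embeddings
model_max_length = 16384    # requested model_args.model_max_length

# Mirrors the patched branch: yarn, like dynamic, now records the original length.
rope_kwargs = {"rope_type": "yarn"}  # the "rope_type" key is an assumption here
rope_kwargs["factor"] = float(math.ceil(model_max_length / current_max_length))
rope_kwargs["original_max_position_embeddings"] = current_max_length

print(rope_kwargs)
# -> {'rope_type': 'yarn', 'factor': 4.0, 'original_max_position_embeddings': 4096}
```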