[model] update rope kwargs for yarn (#8101)

This commit is contained in:
piamo 2025-05-19 20:07:54 +08:00 committed by GitHub
parent ed2f89efaf
commit a6f3adf930
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -56,7 +56,7 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
setattr(config, "max_position_embeddings", model_args.model_max_length)
rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
- if model_args.rope_scaling == RopeScaling.DYNAMIC:
+ if model_args.rope_scaling in [RopeScaling.DYNAMIC, RopeScaling.YARN]:
rope_kwargs["original_max_position_embeddings"] = current_max_length
elif model_args.rope_scaling == RopeScaling.LLAMA3:
rope_kwargs["original_max_position_embeddings"] = current_max_length