mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 23:58:11 +08:00
tiny fix
Former-commit-id: eae79707d31fd8be2cf4bee4d610557bbd49f6e7
This commit is contained in:
parent
83fc73c580
commit
35d04a2c05
@ -69,7 +69,7 @@ def main():
|
|||||||
ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
|
ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
|
||||||
ppo_trainer.save_model()
|
ppo_trainer.save_model()
|
||||||
ppo_trainer.save_state() # must be after save_model
|
ppo_trainer.save_state() # must be after save_model
|
||||||
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
|
||||||
plot_loss(training_args, keys=["loss", "reward"])
|
plot_loss(training_args, keys=["loss", "reward"])
|
||||||
|
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ def main():
|
|||||||
trainer.save_metrics("train", train_result.metrics)
|
trainer.save_metrics("train", train_result.metrics)
|
||||||
trainer.save_state()
|
trainer.save_state()
|
||||||
trainer.save_model()
|
trainer.save_model()
|
||||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||||
plot_loss(training_args, keys=["loss", "eval_loss"])
|
plot_loss(training_args, keys=["loss", "eval_loss"])
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
|
@ -71,7 +71,7 @@ def main():
|
|||||||
trainer.save_metrics("train", train_result.metrics)
|
trainer.save_metrics("train", train_result.metrics)
|
||||||
trainer.save_state()
|
trainer.save_state()
|
||||||
trainer.save_model()
|
trainer.save_model()
|
||||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||||
plot_loss(training_args, keys=["loss", "eval_loss"])
|
plot_loss(training_args, keys=["loss", "eval_loss"])
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
|
@ -91,7 +91,7 @@ def init_adapter(
|
|||||||
lastest_checkpoint = None
|
lastest_checkpoint = None
|
||||||
|
|
||||||
if model_args.checkpoint_dir is not None:
|
if model_args.checkpoint_dir is not None:
|
||||||
if is_trainable and finetuning_args.resume_lora_training: # continually train on the lora weights
|
if is_trainable and model_args.resume_lora_training: # continually train on the lora weights
|
||||||
checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
||||||
else:
|
else:
|
||||||
checkpoints_to_merge = model_args.checkpoint_dir
|
checkpoints_to_merge = model_args.checkpoint_dir
|
||||||
|
@ -51,6 +51,14 @@ class ModelArguments:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
|
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
|
||||||
)
|
)
|
||||||
|
resume_lora_training: Optional[bool] = field(
|
||||||
|
default=True,
|
||||||
|
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
|
||||||
|
)
|
||||||
|
plot_loss: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.checkpoint_dir is not None: # support merging lora weights
|
if self.checkpoint_dir is not None: # support merging lora weights
|
||||||
@ -173,14 +181,6 @@ class FinetuningArguments:
|
|||||||
default="q_proj,v_proj",
|
default="q_proj,v_proj",
|
||||||
metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules."}
|
metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules."}
|
||||||
)
|
)
|
||||||
resume_lora_training: Optional[bool] = field(
|
|
||||||
default=True,
|
|
||||||
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
|
|
||||||
)
|
|
||||||
plot_loss: Optional[bool] = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
|
||||||
)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if isinstance(self.lora_target, str):
|
if isinstance(self.lora_target, str):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user