From cc703b58f5c6539733f0087ff27c0be93dd19534 Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Fri, 20 Dec 2024 16:43:03 +0800 Subject: [PATCH] fix: by hiyouga suggestion Former-commit-id: 3a7ea2048a41eafc41fdca944e142f5a0f35a5b3 --- src/llamafactory/hparams/finetuning_args.py | 4 ++-- src/llamafactory/train/dpo/trainer.py | 5 ++++- src/llamafactory/train/kto/trainer.py | 5 ++++- src/llamafactory/train/ppo/trainer.py | 5 ++++- src/llamafactory/train/pt/trainer.py | 5 ++++- src/llamafactory/train/rm/trainer.py | 5 ++++- src/llamafactory/webui/locales.py | 2 +- 7 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index b39e9f18..4765e1a4 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -308,10 +308,10 @@ class BAdamArgument: class SwanLabArguments: use_swanlab: bool = field( default=False, - metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tools)."}, + metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."}, ) swanlab_project: str = field( - default=None, + default="LLaMA Factory", metadata={"help": "The project name in SwanLab."}, ) swanlab_workspace: str = field( diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index 7e76dee2..0115f834 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -31,7 +31,7 @@ from typing_extensions import override from ...extras.constants import IGNORE_INDEX from ...extras.packages import is_transformers_version_equal_to_4_46 from ..callbacks import PissaConvertCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback if TYPE_CHECKING: @@ -106,6 +106,9 @@ class CustomDPOTrainer(DPOTrainer): self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index e22b16a4..71802375 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -30,7 +30,7 @@ from typing_extensions import override from ...extras.constants import IGNORE_INDEX from ...extras.packages import is_transformers_version_equal_to_4_46 from ..callbacks import SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback if TYPE_CHECKING: @@ -101,6 +101,9 @@ class CustomKTOTrainer(KTOTrainer): self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py index 4ab7a118..a60b7d7c 100644 --- a/src/llamafactory/train/ppo/trainer.py +++ b/src/llamafactory/train/ppo/trainer.py @@ -40,7 +40,7 @@ from typing_extensions import override from ...extras import logging from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm @@ -186,6 +186,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer): self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None: r""" Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer. diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index 37dcadfd..5e4a513d 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -20,7 +20,7 @@ from typing_extensions import override from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than from ..callbacks import PissaConvertCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback if TYPE_CHECKING: @@ -56,6 +56,9 @@ class CustomTrainer(Trainer): self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py index bccfdef5..458c40ff 100644 --- a/src/llamafactory/train/rm/trainer.py +++ b/src/llamafactory/train/rm/trainer.py @@ -27,7 +27,7 @@ from typing_extensions import override from ...extras import logging from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback if TYPE_CHECKING: @@ -68,6 +68,9 @@ class PairwiseTrainer(Trainer): self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index e267b63c..8b5baade 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -1438,7 +1438,7 @@ LOCALES = { }, "swanlab_experiment_name": { "en": { - "label": "Experiment_name(optional)", + "label": "Experiment name (optional)", }, "ru": { "label": "Имя эксперимента(Необязательный)",