diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 8cfea728..faf786d5 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -305,7 +305,21 @@ class BAdamArgument:
 
 
 @dataclass
-class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
+class SwanLabArguments:
+    use_swanlab: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use SwanLab for experiment tracking."},
+    )
+    swanlab_name: str = field(
+        default="",
+        metadata={"help": "The experiment name displayed in SwanLab."},
+    )
+
+
+@dataclass
+class FinetuningArguments(
+    FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument, SwanLabArguments
+):
     r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
     """
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 0f118bbb..be095f65 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -29,7 +29,7 @@ from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
 
 
 if TYPE_CHECKING:
@@ -71,6 +71,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
+        if finetuning_args.use_swanlab:
+            self.add_callback(get_swanlab_callback(finetuning_args))
+
     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index 7d916ec1..d33bc538 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -40,7 +40,7 @@ if is_galore_available():
 
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, Seq2SeqTrainingArguments
+    from transformers import PreTrainedModel, Seq2SeqTrainingArguments, TrainerCallback
     from trl import AutoModelForCausalLMWithValueHead
 
     from ..hparams import DataArguments
@@ -457,3 +457,12 @@ def get_batch_logps(
     labels[labels == label_pad_token_id] = 0  # dummy token
     per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
     return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
+
+
+def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
+    r"""
+    Gets the callback for logging to SwanLab.
+    """
+    from swanlab.integration.huggingface import SwanLabCallback
+
+    return SwanLabCallback(experiment_name=finetuning_args.swanlab_name or None)
diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py
index bd53d163..bc776acd 100644
--- a/src/llamafactory/webui/components/train.py
+++ b/src/llamafactory/webui/components/train.py
@@ -270,6 +270,18 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         )
     )
 
+    with gr.Accordion(open=False) as swanlab_tab:
+        with gr.Row():
+            use_swanlab = gr.Checkbox()
+
+    input_elems.update({use_swanlab})
+    elem_dict.update(
+        dict(
+            swanlab_tab=swanlab_tab,
+            use_swanlab=use_swanlab,
+        )
+    )
+
     with gr.Row():
         cmd_preview_btn = gr.Button()
         arg_save_btn = gr.Button()
diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py
index 45b847b4..c88800d2 100644
--- a/src/llamafactory/webui/locales.py
+++ b/src/llamafactory/webui/locales.py
@@ -1353,6 +1353,20 @@
             "info": "비율-BAdam의 업데이트 비율.",
         },
     },
+    "swanlab_tab": {
+        "en": {
+            "label": "SwanLab configurations",
+        },
+        "ru": {
+            "label": "Конфигурации SwanLab",
+        },
+        "zh": {
+            "label": "SwanLab 参数设置",
+        },
+        "ko": {
+            "label": "SwanLab 설정",
+        },
+    },
     "cmd_preview_btn": {
         "en": {
             "value": "Preview command",
diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py
index da0a9c7e..7bd61390 100644
--- a/src/llamafactory/webui/runner.py
+++ b/src/llamafactory/webui/runner.py
@@ -147,6 +147,7 @@ class Runner:
             report_to="all" if get("train.report_to") else "none",
             use_galore=get("train.use_galore"),
             use_badam=get("train.use_badam"),
+            use_swanlab=get("train.use_swanlab"),
             output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")),
             fp16=(get("train.compute_type") == "fp16"),
             bf16=(get("train.compute_type") == "bf16"),
@@ -228,6 +229,10 @@
             args["badam_switch_interval"] = get("train.badam_switch_interval")
             args["badam_update_ratio"] = get("train.badam_update_ratio")
 
+        # swanlab config
+        if get("train.use_swanlab"):
+            args["swanlab_name"] = get("train.swanlab_name")
+
        # eval config
         if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
             args["val_size"] = get("train.val_size")
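Usage sketch (not part of the diff): a minimal, self-contained script showing how the two new flags are meant to flow into the SwanLab callback. It assumes the `swanlab` package is installed and that `SwanLabCallback` accepts an `experiment_name` keyword, as in SwanLab's Hugging Face integration; the script name and CLI invocation are illustrative only.

# swanlab_sketch.py -- illustrative only; mirrors the SwanLabArguments fields added above.
from dataclasses import dataclass, field

from swanlab.integration.huggingface import SwanLabCallback  # requires `pip install swanlab`
from transformers import HfArgumentParser


@dataclass
class SwanLabArguments:
    use_swanlab: bool = field(
        default=False,
        metadata={"help": "Whether or not to use SwanLab for experiment tracking."},
    )
    swanlab_name: str = field(
        default="",
        metadata={"help": "The experiment name displayed in SwanLab."},
    )


def get_swanlab_callback(args: SwanLabArguments) -> SwanLabCallback:
    # Assumption: SwanLabCallback accepts `experiment_name`; fall back to SwanLab's default when unset.
    return SwanLabCallback(experiment_name=args.swanlab_name or None)


if __name__ == "__main__":
    # e.g. python swanlab_sketch.py --use_swanlab true --swanlab_name llama3-lora-sft
    (args,) = HfArgumentParser(SwanLabArguments).parse_args_into_dataclasses()
    if args.use_swanlab:
        callback = get_swanlab_callback(args)  # a trainer would pick this up via trainer.add_callback(callback)
        print(f"SwanLab callback ready: {type(callback).__name__}")

With the webui checkbox enabled, the same two values reach CustomSeq2SeqTrainer through FinetuningArguments, and the callback is attached in __init__ exactly as in the trainer.py hunk above.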