From 1a483406809791c3bde2187c9bff7838c5857051 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 19 Dec 2024 07:12:31 +0000 Subject: [PATCH 01/10] add swanlab Former-commit-id: 96f8f103e58a8ff307b0ce36c967de04f452434a --- src/llamafactory/hparams/finetuning_args.py | 16 +++++++++++++++- src/llamafactory/train/sft/trainer.py | 5 ++++- src/llamafactory/train/trainer_utils.py | 11 ++++++++++- src/llamafactory/webui/components/train.py | 12 ++++++++++++ src/llamafactory/webui/locales.py | 14 ++++++++++++++ src/llamafactory/webui/runner.py | 5 +++++ 6 files changed, 60 insertions(+), 3 deletions(-) diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 8cfea728..faf786d5 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -305,7 +305,21 @@ class BAdamArgument: @dataclass -class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument): +class SwanLabArguments: + use_swanlab: bool = field( + default=False, + metadata={"help": ""}, + ) + swanlab_name: str = field( + default="", + metadata={}, + ) + + +@dataclass +class FinetuningArguments( + FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument, SwanLabArguments +): r""" Arguments pertaining to which techniques we are going to fine-tuning with. """ diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 0f118bbb..be095f65 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -29,7 +29,7 @@ from ...extras import logging from ...extras.constants import IGNORE_INDEX from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than from ..callbacks import PissaConvertCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback if TYPE_CHECKING: @@ -71,6 +71,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 7d916ec1..d33bc538 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -40,7 +40,7 @@ if is_galore_available(): if TYPE_CHECKING: - from transformers import PreTrainedModel, Seq2SeqTrainingArguments + from transformers import PreTrainedModel, Seq2SeqTrainingArguments, TrainerCallback from trl import AutoModelForCausalLMWithValueHead from ..hparams import DataArguments @@ -457,3 +457,12 @@ def get_batch_logps( labels[labels == label_pad_token_id] = 0 # dummy token per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1) + + +def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback": + r""" + Gets the callback for logging to SwanLab. 
+ """ + from swanlab.integration.huggingface import SwanLabCallback + + return SwanLabCallback() diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index bd53d163..bc776acd 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -270,6 +270,18 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: ) ) + with gr.Accordion(open=False) as swanlab_tab: + with gr.Row(): + use_swanlab = gr.Checkbox() + + input_elems.update({use_swanlab}) + elem_dict.update( + dict( + swanlab_tab=swanlab_tab, + use_swanlab=use_swanlab, + ) + ) + with gr.Row(): cmd_preview_btn = gr.Button() arg_save_btn = gr.Button() diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index 45b847b4..c88800d2 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -1353,6 +1353,20 @@ LOCALES = { "info": "비율-BAdam의 업데이트 비율.", }, }, + "swanlab_tab": { + "en": { + "label": "SwanLab configurations", + }, + "ru": { + "label": "Конфигурации SwanLab", + }, + "zh": { + "label": "SwanLab 参数设置", + }, + "ko": { + "label": "SwanLab 설정", + }, + }, "cmd_preview_btn": { "en": { "value": "Preview command", diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index da0a9c7e..7bd61390 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -147,6 +147,7 @@ class Runner: report_to="all" if get("train.report_to") else "none", use_galore=get("train.use_galore"), use_badam=get("train.use_badam"), + use_swanlab=get("train.use_swanlab"), output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")), fp16=(get("train.compute_type") == "fp16"), bf16=(get("train.compute_type") == "bf16"), @@ -228,6 +229,10 @@ class Runner: args["badam_switch_interval"] = get("train.badam_switch_interval") args["badam_update_ratio"] = get("train.badam_update_ratio") + # swanlab config + if get("train.use_swanlab"): + args["swanlab_name"] = get("train.swanlab_name") + # eval config if get("train.val_size") > 1e-6 and args["stage"] != "ppo": args["val_size"] = get("train.val_size") From cc5cde734bb7b5e713bb2c5cf5ba07e4f365d3e0 Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Thu, 19 Dec 2024 18:47:27 +0800 Subject: [PATCH 02/10] feat: swanlab params Former-commit-id: d5cf87990e5bea920ecd1561def09fa17cf328b1 --- src/llamafactory/hparams/finetuning_args.py | 26 ++++++++++++++++++--- src/llamafactory/train/trainer_utils.py | 16 +++++++++++-- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index faf786d5..59194329 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -308,11 +308,31 @@ class BAdamArgument: class SwanLabArguments: use_swanlab: bool = field( default=False, - metadata={"help": ""}, + metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tools)."}, ) - swanlab_name: str = field( + swanlab_project: str = field( default="", - metadata={}, + metadata={"help": "The project name in SwanLab."}, + ) + swanlab_workspace: str = field( + default="", + metadata={"help": "The workspace name in SwanLab."}, + ) + swanlab_experiment_name: str = field( + default="", + metadata={"help": "The experiment name in SwanLab."}, + ) + swanlab_description: str = field( + default="", + metadata={"help": "The experiment 
description in SwanLab."}, + ) + swanlab_mode: Literal["cloud", "local", "disabled"] = field( + default="cloud", + metadata={"help": "The mode of SwanLab."}, + ) + swanlab_api_key: str = field( + default="", + metadata={"help": "The API key for SwanLab."}, ) diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index d33bc538..e13fe552 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -463,6 +463,18 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall r""" Gets the callback for logging to SwanLab. """ - from swanlab.integration.huggingface import SwanLabCallback + import swanlab + from swanlab.integration.transformers import SwanLabCallback + + if finetuning_args.swanlab_api_key is not None: + swanlab.login(api_key=finetuning_args.swanlab_api_key) - return SwanLabCallback() + swanlab_callback = SwanLabCallback( + project=finetuning_args.swanlab_project, + workspace=finetuning_args.swanlab_workspace, + experiment_name=finetuning_args.swanlab_experiment_name, + description=finetuning_args.swanlab_description, + mode=finetuning_args.swanlab_mode, + ) + + return swanlab_callback \ No newline at end of file From 53103f55b695bd40b78e4329f6f304d3b60b0743 Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Thu, 19 Dec 2024 19:04:19 +0800 Subject: [PATCH 03/10] feat: optimize frontend Former-commit-id: 8c2df41b937f491f7ebf593b20c65a19738c7642 --- src/llamafactory/train/trainer_utils.py | 2 +- src/llamafactory/webui/components/train.py | 12 +++ src/llamafactory/webui/locales.py | 106 +++++++++++++++++++++ src/llamafactory/webui/runner.py | 8 +- 4 files changed, 126 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index e13fe552..583ca3ee 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -465,7 +465,7 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall """ import swanlab from swanlab.integration.transformers import SwanLabCallback - + if finetuning_args.swanlab_api_key is not None: swanlab.login(api_key=finetuning_args.swanlab_api_key) diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index bc776acd..f905e43e 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -273,12 +273,24 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Accordion(open=False) as swanlab_tab: with gr.Row(): use_swanlab = gr.Checkbox() + swanlab_api_key = gr.Textbox(value="", placeholder="API key", interactive=True) + swanlab_project = gr.Textbox(value="", placeholder="Project name", interactive=True) + swanlab_workspace = gr.Textbox(value="", placeholder="Workspace name", interactive=True) + swanlab_experiment_name = gr.Textbox(value="", placeholder="Experiment name", interactive=True) + swanlab_description = gr.Textbox(value="", placeholder="Experiment description", interactive=True) + swanlab_mode = gr.Dropdown(choices=["cloud", "local", "disabled"], value="cloud", interactive=True) input_elems.update({use_swanlab}) elem_dict.update( dict( swanlab_tab=swanlab_tab, use_swanlab=use_swanlab, + swanlab_api_key=swanlab_api_key, + swanlab_project=swanlab_project, + swanlab_workspace=swanlab_workspace, + swanlab_experiment_name=swanlab_experiment_name, + swanlab_description=swanlab_description, + 
swanlab_mode=swanlab_mode, ) ) diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index c88800d2..2167bfd1 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -1367,6 +1367,112 @@ LOCALES = { "label": "SwanLab 설정", }, }, + "use_swanlab": { + "en": { + "label": "Use SwanLab", + "info": "Enable SwanLab for experiment tracking and visualization.", + }, + "ru": { + "label": "Использовать SwanLab", + "info": "Включить SwanLab для отслеживания и визуализации экспериментов.", + }, + "zh": { + "label": "使用 SwanLab", + "info": "启用 SwanLab 进行实验跟踪和可视化。", + }, + "ko": { + "label": "SwanLab 사용", + "info": "SwanLab를 사용하여 실험을 추적하고 시각화합니다.", + }, + }, + "swanlab_api_key": { + "en": { + "label": "API key", + "info": "The API key for SwanLab.", + }, + "ru": { + "label": "API ключ", + "info": "API ключ для SwanLab.", + }, + "zh": { + "label": "API 密钥", + "info": "SwanLab 的 API 密钥。", + }, + "ko": { + "label": "API 키", + "info": "SwanLab의 API 키.", + }, + }, + "swanlab_project": { + "en": { + "label": "SwanLab project", + }, + "ru": { + "label": "Проект SwanLab", + }, + "zh": { + "label": "SwanLab 项目", + }, + "ko": { + "label": "SwanLab 프로젝트", + }, + }, + "swanlab_workspace": { + "en": { + "label": "SwanLab workspace", + }, + "ru": { + "label": "Рабочая область SwanLab", + }, + "zh": { + "label": "SwanLab 工作区", + }, + "ko": { + "label": "SwanLab 작업 영역", + }, + }, + "swanlab_experiment_name": { + "en": { + "label": "SwanLab experiment name", + }, + "ru": { + "label": "Имя эксперимента SwanLab", + }, + "zh": { + "label": "SwanLab 实验名称", + }, + "ko": { + "label": "SwanLab 실험 이름", + }, + }, + "swanlab_description": { + "en": { + "label": "SwanLab experiment description", + }, + "ru": { + "label": "Описание эксперимента SwanLab", + }, + "zh": { + "label": "SwanLab 实验描述", + }, + "ko": { + "label": "SwanLab 실험 설명", + }, + }, + "swanlab_mode": { + "en": { + "label": "SwanLab mode", + }, + "ru": { + "label": "Режим SwanLab", + }, + "zh": { + "label": "SwanLab 模式", + }, + "ko": { + "label": "SwanLab 모드", + }, + }, "cmd_preview_btn": { "en": { "value": "Preview command", diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 7bd61390..6c2b4feb 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -231,7 +231,13 @@ class Runner: # swanlab config if get("train.use_swanlab"): - args["swanlab_name"] = get("train.swanlab_name") + args["swanlab_api_key"] = get("train.swanlab_api_key") + args["swanlab_project"] = get("train.swanlab_project") + args["swanlab_workspace"] = get("train.swanlab_workspace") + args["swanlab_experiment_name"] = get("train.swanlab_experiment_name") + args["swanlab_description"] = get("train.swanlab_description") + args["swanlab_mode"] = get("train.swanlab_mode") + # eval config if get("train.val_size") > 1e-6 and args["stage"] != "ppo": From c31933ef9eb147942cf38a20edea9a6d7cabe2aa Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Thu, 19 Dec 2024 20:18:59 +0800 Subject: [PATCH 04/10] fix: string Former-commit-id: 330691962960fdd2053236e43a919e8f15e2bf27 --- src/llamafactory/hparams/finetuning_args.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 59194329..72b6a0a9 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -311,19 +311,19 @@ class SwanLabArguments: 
metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tools)."}, ) swanlab_project: str = field( - default="", + default=None, metadata={"help": "The project name in SwanLab."}, ) swanlab_workspace: str = field( - default="", + default=None, metadata={"help": "The workspace name in SwanLab."}, ) swanlab_experiment_name: str = field( - default="", + default=None, metadata={"help": "The experiment name in SwanLab."}, ) swanlab_description: str = field( - default="", + default=None, metadata={"help": "The experiment description in SwanLab."}, ) swanlab_mode: Literal["cloud", "local", "disabled"] = field( @@ -331,7 +331,7 @@ class SwanLabArguments: metadata={"help": "The mode of SwanLab."}, ) swanlab_api_key: str = field( - default="", + default=None, metadata={"help": "The API key for SwanLab."}, ) From b512a06c3d2aac2391f7f5a8988facda9c7e7452 Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Thu, 19 Dec 2024 20:22:36 +0800 Subject: [PATCH 05/10] docs: config framework Former-commit-id: 7eb49e5ffaea59d8a2756ae7ff55bd57b9077f4b --- src/llamafactory/hparams/finetuning_args.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 72b6a0a9..aaa28a6a 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -334,6 +334,10 @@ class SwanLabArguments: default=None, metadata={"help": "The API key for SwanLab."}, ) + swanlab_config: dict = field( + default={"Framework": "🦙LLaMA Factory"}, + metadata={"help": "The configuration file for SwanLab."}, + ) @dataclass From dd22454fc5147d3d2778bb140ae50339830cf911 Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Thu, 19 Dec 2024 21:08:16 +0800 Subject: [PATCH 06/10] fix: bugs Former-commit-id: d0eb64d5e3472a166c9adac4cb4ba06bdd663e46 --- src/llamafactory/hparams/finetuning_args.py | 4 ---- src/llamafactory/train/trainer_utils.py | 1 + src/llamafactory/webui/components/train.py | 2 +- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index aaa28a6a..72b6a0a9 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -334,10 +334,6 @@ class SwanLabArguments: default=None, metadata={"help": "The API key for SwanLab."}, ) - swanlab_config: dict = field( - default={"Framework": "🦙LLaMA Factory"}, - metadata={"help": "The configuration file for SwanLab."}, - ) @dataclass diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 583ca3ee..2a50bf71 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -475,6 +475,7 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall experiment_name=finetuning_args.swanlab_experiment_name, description=finetuning_args.swanlab_description, mode=finetuning_args.swanlab_mode, + config={"Framework": "🦙LLaMA Factory"}, ) return swanlab_callback \ No newline at end of file diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index f905e43e..6b14bb81 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -280,7 +280,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: swanlab_description = gr.Textbox(value="", placeholder="Experiment 
description", interactive=True) swanlab_mode = gr.Dropdown(choices=["cloud", "local", "disabled"], value="cloud", interactive=True) - input_elems.update({use_swanlab}) + input_elems.update({use_swanlab, swanlab_api_key, swanlab_project, swanlab_workspace, swanlab_experiment_name, swanlab_description, swanlab_mode}) elem_dict.update( dict( swanlab_tab=swanlab_tab, From 03dba638e6cd91fde5a7ce0b427f1cd6bd1adddd Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Thu, 19 Dec 2024 21:26:02 +0800 Subject: [PATCH 07/10] fix: text Former-commit-id: 0a52962db365e7456c858a8e58c19313f19d1e09 --- src/llamafactory/webui/components/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index 6b14bb81..d7237bd6 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -273,12 +273,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Accordion(open=False) as swanlab_tab: with gr.Row(): use_swanlab = gr.Checkbox() - swanlab_api_key = gr.Textbox(value="", placeholder="API key", interactive=True) - swanlab_project = gr.Textbox(value="", placeholder="Project name", interactive=True) - swanlab_workspace = gr.Textbox(value="", placeholder="Workspace name", interactive=True) + swanlab_project = gr.Textbox(value="LLaMA-Factory", placeholder="Project name", interactive=True) swanlab_experiment_name = gr.Textbox(value="", placeholder="Experiment name", interactive=True) swanlab_description = gr.Textbox(value="", placeholder="Experiment description", interactive=True) swanlab_mode = gr.Dropdown(choices=["cloud", "local", "disabled"], value="cloud", interactive=True) + swanlab_workspace = gr.Textbox(value="", placeholder="Workspace name", interactive=True) + swanlab_api_key = gr.Textbox(value="", placeholder="API key", interactive=True) input_elems.update({use_swanlab, swanlab_api_key, swanlab_project, swanlab_workspace, swanlab_experiment_name, swanlab_description, swanlab_mode}) elem_dict.update( From 8f786ee938e656589b5308dda6d1f54d1defccc3 Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Fri, 20 Dec 2024 11:03:02 +0800 Subject: [PATCH 08/10] feat: ui improve Former-commit-id: 5f6dafd70e962b8fe9a294d555133002135f80df --- src/llamafactory/hparams/finetuning_args.py | 6 +- src/llamafactory/train/trainer_utils.py | 1 - src/llamafactory/webui/components/train.py | 6 +- src/llamafactory/webui/locales.py | 71 ++++++++++----------- src/llamafactory/webui/runner.py | 1 - 5 files changed, 36 insertions(+), 49 deletions(-) diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 72b6a0a9..b39e9f18 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -322,11 +322,7 @@ class SwanLabArguments: default=None, metadata={"help": "The experiment name in SwanLab."}, ) - swanlab_description: str = field( - default=None, - metadata={"help": "The experiment description in SwanLab."}, - ) - swanlab_mode: Literal["cloud", "local", "disabled"] = field( + swanlab_mode: Literal["cloud", "local"] = field( default="cloud", metadata={"help": "The mode of SwanLab."}, ) diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 2a50bf71..5b8fb403 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -473,7 +473,6 @@ def 
get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall project=finetuning_args.swanlab_project, workspace=finetuning_args.swanlab_workspace, experiment_name=finetuning_args.swanlab_experiment_name, - description=finetuning_args.swanlab_description, mode=finetuning_args.swanlab_mode, config={"Framework": "🦙LLaMA Factory"}, ) diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index d7237bd6..8766d5b9 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -275,12 +275,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: use_swanlab = gr.Checkbox() swanlab_project = gr.Textbox(value="LLaMA-Factory", placeholder="Project name", interactive=True) swanlab_experiment_name = gr.Textbox(value="", placeholder="Experiment name", interactive=True) - swanlab_description = gr.Textbox(value="", placeholder="Experiment description", interactive=True) - swanlab_mode = gr.Dropdown(choices=["cloud", "local", "disabled"], value="cloud", interactive=True) swanlab_workspace = gr.Textbox(value="", placeholder="Workspace name", interactive=True) swanlab_api_key = gr.Textbox(value="", placeholder="API key", interactive=True) + swanlab_mode = gr.Dropdown(choices=["cloud", "local", "disabled"], value="cloud", interactive=True) - input_elems.update({use_swanlab, swanlab_api_key, swanlab_project, swanlab_workspace, swanlab_experiment_name, swanlab_description, swanlab_mode}) + input_elems.update({use_swanlab, swanlab_api_key, swanlab_project, swanlab_workspace, swanlab_experiment_name, swanlab_mode}) elem_dict.update( dict( swanlab_tab=swanlab_tab, @@ -289,7 +288,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: swanlab_project=swanlab_project, swanlab_workspace=swanlab_workspace, swanlab_experiment_name=swanlab_experiment_name, - swanlab_description=swanlab_description, swanlab_mode=swanlab_mode, ) ) diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index 2167bfd1..e267b63c 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -1387,90 +1387,85 @@ LOCALES = { }, "swanlab_api_key": { "en": { - "label": "API key", - "info": "The API key for SwanLab.", + "label": "API Key(optional)", + "info": "API key for SwanLab. Once logged in, no need to login again in the programming environment.", }, "ru": { - "label": "API ключ", - "info": "API ключ для SwanLab.", + "label": "API ключ(Необязательный)", + "info": "API ключ для SwanLab. После входа в программное окружение, нет необходимости входить снова.", }, "zh": { - "label": "API 密钥", - "info": "SwanLab 的 API 密钥。", + "label": "API密钥(选填)", + "info": "用于在编程环境登录SwanLab,已登录则无需填写。", }, "ko": { - "label": "API 키", - "info": "SwanLab의 API 키.", + "label": "API 키(선택 사항)", + "info": "SwanLab의 API 키. 프로그래밍 환경에 로그인한 후 다시 로그인할 필요가 없습니다.", }, }, "swanlab_project": { "en": { - "label": "SwanLab project", + "label": "Project(optional)", }, "ru": { - "label": "Проект SwanLab", + "label": "Проект(Необязательный)", }, "zh": { - "label": "SwanLab 项目", + "label": "项目(选填)", }, "ko": { - "label": "SwanLab 프로젝트", + "label": "프로젝트(선택 사항)", }, }, "swanlab_workspace": { "en": { - "label": "SwanLab workspace", + "label": "Workspace(optional)", + "info": "Workspace for SwanLab. 
If not filled, it defaults to the personal workspace.", + }, "ru": { - "label": "Рабочая область SwanLab", + "label": "Рабочая область(Необязательный)", + "info": "Рабочая область SwanLab, если не заполнено, то по умолчанию в личной рабочей области.", }, "zh": { - "label": "SwanLab 工作区", + "label": "Workspace(选填)", + "info": "SwanLab组织的工作区,如不填写则默认在个人工作区下", }, "ko": { - "label": "SwanLab 작업 영역", + "label": "작업 영역(선택 사항)", + "info": "SwanLab 조직의 작업 영역, 비어 있으면 기본적으로 개인 작업 영역에 있습니다.", }, }, "swanlab_experiment_name": { "en": { - "label": "SwanLab experiment name", + "label": "Experiment_name(optional)", }, "ru": { - "label": "Имя эксперимента SwanLab", + "label": "Имя эксперимента(Необязательный)", }, "zh": { - "label": "SwanLab 实验名称", + "label": "实验名(选填) ", }, "ko": { - "label": "SwanLab 실험 이름", - }, - }, - "swanlab_description": { - "en": { - "label": "SwanLab experiment description", - }, - "ru": { - "label": "Описание эксперимента SwanLab", - }, - "zh": { - "label": "SwanLab 实验描述", - }, - "ko": { - "label": "SwanLab 실험 설명", + "label": "실험 이름(선택 사항)", }, }, "swanlab_mode": { "en": { - "label": "SwanLab mode", + "label": "Mode", + "info": "Cloud or offline version.", }, "ru": { - "label": "Режим SwanLab", + "label": "Режим", + "info": "Версия в облаке или локальная версия.", }, "zh": { - "label": "SwanLab 模式", + "label": "模式", + "info": "云端版或离线版", }, "ko": { - "label": "SwanLab 모드", + "label": "모드", + "info": "클라우드 버전 또는 오프라인 버전.", }, }, "cmd_preview_btn": { diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 6c2b4feb..2b5c55b8 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -235,7 +235,6 @@ class Runner: args["swanlab_project"] = get("train.swanlab_project") args["swanlab_workspace"] = get("train.swanlab_workspace") args["swanlab_experiment_name"] = get("train.swanlab_experiment_name") - args["swanlab_description"] = get("train.swanlab_description") args["swanlab_mode"] = get("train.swanlab_mode") From cc703b58f5c6539733f0087ff27c0be93dd19534 Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Fri, 20 Dec 2024 16:43:03 +0800 Subject: [PATCH 09/10] fix: by hiyouga suggestion Former-commit-id: 3a7ea2048a41eafc41fdca944e142f5a0f35a5b3 --- src/llamafactory/hparams/finetuning_args.py | 4 ++-- src/llamafactory/train/dpo/trainer.py | 5 ++++- src/llamafactory/train/kto/trainer.py | 5 ++++- src/llamafactory/train/ppo/trainer.py | 5 ++++- src/llamafactory/train/pt/trainer.py | 5 ++++- src/llamafactory/train/rm/trainer.py | 5 ++++- src/llamafactory/webui/locales.py | 2 +- 7 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index b39e9f18..4765e1a4 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -308,10 +308,10 @@ class BAdamArgument: class SwanLabArguments: use_swanlab: bool = field( default=False, - metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tools)."}, + metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."}, ) swanlab_project: str = field( - default=None, + default="LLaMA Factory", metadata={"help": "The project name in SwanLab."}, ) swanlab_workspace: str = field( diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index 7e76dee2..0115f834 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ 
b/src/llamafactory/train/dpo/trainer.py @@ -31,7 +31,7 @@ from typing_extensions import override from ...extras.constants import IGNORE_INDEX from ...extras.packages import is_transformers_version_equal_to_4_46 from ..callbacks import PissaConvertCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback if TYPE_CHECKING: @@ -106,6 +106,9 @@ class CustomDPOTrainer(DPOTrainer): self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index e22b16a4..71802375 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -30,7 +30,7 @@ from typing_extensions import override from ...extras.constants import IGNORE_INDEX from ...extras.packages import is_transformers_version_equal_to_4_46 from ..callbacks import SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback if TYPE_CHECKING: @@ -101,6 +101,9 @@ class CustomKTOTrainer(KTOTrainer): self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py index 4ab7a118..a60b7d7c 100644 --- a/src/llamafactory/train/ppo/trainer.py +++ b/src/llamafactory/train/ppo/trainer.py @@ -40,7 +40,7 @@ from typing_extensions import override from ...extras import logging from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm @@ -186,6 +186,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer): self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None: r""" Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer. 
diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index 37dcadfd..5e4a513d 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -20,7 +20,7 @@ from typing_extensions import override from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than from ..callbacks import PissaConvertCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback if TYPE_CHECKING: @@ -56,6 +56,9 @@ class CustomTrainer(Trainer): self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py index bccfdef5..458c40ff 100644 --- a/src/llamafactory/train/rm/trainer.py +++ b/src/llamafactory/train/rm/trainer.py @@ -27,7 +27,7 @@ from typing_extensions import override from ...extras import logging from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback if TYPE_CHECKING: @@ -68,6 +68,9 @@ class PairwiseTrainer(Trainer): self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index e267b63c..8b5baade 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -1438,7 +1438,7 @@ LOCALES = { }, "swanlab_experiment_name": { "en": { - "label": "Experiment_name(optional)", + "label": "Experiment name (optional)", }, "ru": { "label": "Имя эксперимента(Необязательный)", From 67d4757c35fa70e91adc784118f91d39e5706def Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Fri, 20 Dec 2024 18:26:02 +0800 Subject: [PATCH 10/10] fix: project blank Former-commit-id: 82e5d75014ffe5fbe762711adecf59c94ab29f59 --- src/llamafactory/hparams/finetuning_args.py | 2 +- src/llamafactory/webui/components/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 4765e1a4..4fa12ff3 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -311,7 +311,7 @@ class SwanLabArguments: metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."}, ) swanlab_project: str = field( - default="LLaMA Factory", + default="llamafactory", metadata={"help": "The project name in SwanLab."}, ) swanlab_workspace: str = field( diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index 8766d5b9..399823d8 100644 --- 
a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -273,7 +273,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Accordion(open=False) as swanlab_tab: with gr.Row(): use_swanlab = gr.Checkbox() - swanlab_project = gr.Textbox(value="LLaMA-Factory", placeholder="Project name", interactive=True) + swanlab_project = gr.Textbox(value="llamafactory", placeholder="Project name", interactive=True) swanlab_experiment_name = gr.Textbox(value="", placeholder="Experiment name", interactive=True) swanlab_workspace = gr.Textbox(value="", placeholder="Workspace name", interactive=True) swanlab_api_key = gr.Textbox(value="", placeholder="API key", interactive=True)
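
Taken together, the series adds a get_swanlab_callback factory in trainer_utils.py, registers it on the SFT, DPO, KTO, PPO, PT and RM trainers behind the use_swanlab flag, and exposes the SwanLab fields in the web UI and runner. Below is a minimal sketch of the resulting wiring, assuming the swanlab package is installed; SwanLabArgs is a hypothetical stand-in for the SwanLabArguments dataclass, and the attachment comment at the end mirrors what each trainer patch does.

# Minimal sketch of the SwanLab wiring added in this series. Assumes the
# swanlab package is installed; SwanLabArgs is a hypothetical stand-in for
# llamafactory's SwanLabArguments dataclass.
from dataclasses import dataclass
from typing import Literal, Optional


@dataclass
class SwanLabArgs:
    use_swanlab: bool = True
    swanlab_project: str = "llamafactory"
    swanlab_workspace: Optional[str] = None
    swanlab_experiment_name: Optional[str] = None
    swanlab_mode: Literal["cloud", "local"] = "cloud"
    swanlab_api_key: Optional[str] = None


def get_swanlab_callback(args: SwanLabArgs):
    # Mirrors the factory introduced in trainer_utils.py: log in once if an
    # API key is given, then build the Transformers integration callback.
    import swanlab
    from swanlab.integration.transformers import SwanLabCallback

    if args.swanlab_api_key is not None:
        swanlab.login(api_key=args.swanlab_api_key)

    return SwanLabCallback(
        project=args.swanlab_project,
        workspace=args.swanlab_workspace,
        experiment_name=args.swanlab_experiment_name,
        mode=args.swanlab_mode,
        config={"Framework": "🦙LLaMA Factory"},
    )


# Each trainer patch attaches the callback the same way:
#
#     if finetuning_args.use_swanlab:
#         self.add_callback(get_swanlab_callback(finetuning_args))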