Merge pull request #6418 from hiyouga/hiyouga/add_report

[trainer] add custom args to experimental logger

Former-commit-id: d58746eca203d97ec57abbc312ecf4c00b5d5535
hoshi-hiyouga 2024-12-22 05:47:55 +08:00 committed by GitHub
commit c0418062c0
20 changed files with 164 additions and 124 deletions

.gitignore

@@ -171,4 +171,5 @@ config/
 saves/
 output/
 wandb/
+swanlog/
 generated_predictions.jsonl


@@ -4,7 +4,7 @@
 [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
 [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
-[![Citation](https://img.shields.io/badge/citation-93-green)](#projects-using-llama-factory)
+[![Citation](https://img.shields.io/badge/citation-196-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
 [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -13,6 +13,7 @@
 [![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
 [![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
 [![SageMaker](https://img.shields.io/badge/SageMaker-Open%20in%20AWS-blue)](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
+[![GitCode](https://gitcode.com/zhengyaowei/LLaMA-Factory/star/badge.svg)](https://gitcode.com/zhengyaowei/LLaMA-Factory)
 [![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)
@@ -87,18 +88,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog
-[24/12/21] We supported **[SwanLab](https://github.com/SwanHubX/SwanLab)** experiment tracking and visualization. See [this section](#use-swanlab-logger) for details.
+[24/12/21] We supported using **[SwanLab](https://github.com/SwanHubX/SwanLab)** for experiment tracking and visualization. See [this section](#use-swanlab-logger) for details.
 [24/11/27] We supported fine-tuning the **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** model and the **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** dataset.
 [24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.
+<details><summary>Full Changelog</summary>
 [24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.
 [24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
-<details><summary>Full Changelog</summary>
 [24/08/27] We supported **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.
 [24/08/09] We supported **[Adam-mini](https://github.com/zyushun/Adam-mini)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
@@ -388,7 +389,7 @@ cd LLaMA-Factory
 pip install -e ".[torch,metrics]"
 ```
-Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, quality
+Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, swanlab, quality
 > [!TIP]
 > Use `pip install --no-deps -e .` to resolve package conflicts.
@@ -642,8 +643,7 @@ To use [SwanLab](https://github.com/SwanHubX/SwanLab) for logging experimental results
 ```yaml
 use_swanlab: true
-swanlab_project: test_project # optional
-swanlab_experiment_name: test_experiment # optional
+swanlab_run_name: test_run # optional
 ```
 When launching training tasks, you can log in to SwanLab in three ways:
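For reference, the same keys can be passed programmatically: the sketch below feeds an equivalent config dict to the `run_exp` entry point that this PR extends. The import path and the extra training keys are illustrative assumptions, not part of this change; a real run needs a complete training config.

```python
from llamafactory.train.tuner import run_exp  # assumed import path

# Hypothetical minimal config: the SwanLab keys mirror the YAML above,
# the remaining keys are placeholders for a complete training setup.
run_exp({
    "stage": "sft",
    "do_train": True,
    "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model
    "dataset": "identity",                               # placeholder dataset
    "output_dir": "saves/demo",
    "use_swanlab": True,
    "swanlab_project": "test_project",
    "swanlab_run_name": "test_run",
})
```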


@@ -4,7 +4,7 @@
 [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
 [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
-[![Citation](https://img.shields.io/badge/citation-93-green)](#使用了-llama-factory-的项目)
+[![Citation](https://img.shields.io/badge/citation-196-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
 [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -13,6 +13,7 @@
 [![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
 [![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
 [![SageMaker](https://img.shields.io/badge/SageMaker-Open%20in%20AWS-blue)](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
+[![GitCode](https://gitcode.com/zhengyaowei/LLaMA-Factory/star/badge.svg)](https://gitcode.com/zhengyaowei/LLaMA-Factory)
 [![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)
@@ -88,18 +89,18 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 ## 更新日志
-[24/12/21] 我们支持了 **[SwanLab](https://github.com/SwanHubX/SwanLab)** 跟踪与可视化实验。详细用法请参考 [此部分](#使用-wb-面板)。
+[24/12/21] 我们支持了使用 **[SwanLab](https://github.com/SwanHubX/SwanLab)** 跟踪与可视化实验。详细用法请参考 [此部分](#使用-swanlab-面板)。
 [24/11/27] 我们支持了 **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** 模型的微调和 **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** 数据集。
 [24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。
+<details><summary>展开日志</summary>
 [24/09/19] 我们支持了 **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** 模型的微调。
 [24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。感谢 [@simonJJJ](https://github.com/simonJJJ) 的 PR。
-<details><summary>展开日志</summary>
 [24/08/27] 我们支持了 **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**。请使用 `enable_liger_kernel: true` 来加速训练。
 [24/08/09] 我们支持了 **[Adam-mini](https://github.com/zyushun/Adam-mini)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@relic-yuexi](https://github.com/relic-yuexi) 的 PR。
@@ -389,7 +390,7 @@ cd LLaMA-Factory
 pip install -e ".[torch,metrics]"
 ```
-可选的额外依赖项：torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、openmind、quality
+可选的额外依赖项：torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、openmind、swanlab、quality
 > [!TIP]
 > 遇到包冲突时，可使用 `pip install --no-deps -e .` 解决。
@@ -643,8 +644,7 @@ run_name: test_run # 可选
 ```yaml
 use_swanlab: true
-swanlab_project: test_run # 可选
-swanlab_experiment_name: test_experiment # 可选
+swanlab_run_name: test_run # 可选
 ```
 在启动训练任务时，登录SwanLab账户有以下三种方式：
@@ -653,7 +653,6 @@ swanlab_experiment_name: test_experiment # 可选
 方式二：将环境变量 `SWANLAB_API_KEY` 设置为你的 [API 密钥](https://swanlab.cn/settings)。
 方式三：启动前使用 `swanlab login` 命令完成登录。
 ## 使用了 LLaMA Factory 的项目
 如果您有项目希望添加至下述列表，请通过邮件联系或者创建一个 PR。


@@ -61,6 +61,7 @@ extra_require = {
     "qwen": ["transformers_stream_generator"],
     "modelscope": ["modelscope"],
     "openmind": ["openmind"],
+    "swanlab": ["swanlab"],
     "dev": ["pre-commit", "ruff", "pytest"],
 }


@@ -171,6 +171,9 @@ class HuggingfaceEngine(BaseEngine):
             elif not isinstance(value, torch.Tensor):
                 value = torch.tensor(value)
 
+            if torch.is_floating_point(value):
+                value = value.to(model.dtype)
+
             gen_kwargs[key] = value.to(model.device)
 
         return gen_kwargs, prompt_length
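For context, the added `is_floating_point` check guards against dtype mismatches between generation kwargs and a half-precision model: `torch.tensor()` on a Python float yields float32 by default. A minimal stand-alone illustration (values are placeholders):

```python
import torch

model_dtype = torch.bfloat16          # stand-in for model.dtype
value = torch.tensor(0.7)             # e.g. a sampling parameter passed as a plain float
print(value.dtype)                    # torch.float32 by default

if torch.is_floating_point(value):
    value = value.to(model_dtype)     # align dtype before moving to model.device

print(value.dtype)                    # torch.bfloat16
```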


@@ -15,8 +15,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from dataclasses import dataclass, field
-from typing import Literal, Optional
+from dataclasses import asdict, dataclass, field
+from typing import Any, Dict, Literal, Optional
 
 
 @dataclass
@@ -161,3 +161,6 @@ class DataArguments:
         if self.mask_history and self.train_on_prompt:
             raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
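The new `to_dict` helper simply serializes the dataclass via `asdict`, so the whole argument group can be handed to an experiment tracker. A tiny self-contained sketch of the pattern (field names are illustrative, not the real `DataArguments` fields):

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional

@dataclass
class DataArgumentsSketch:
    template: Optional[str] = None
    cutoff_len: int = 2048
    mask_history: bool = False

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)  # plain dict, ready for wandb/swanlab config

print(DataArgumentsSketch(template="llama3").to_dict())
# {'template': 'llama3', 'cutoff_len': 2048, 'mask_history': False}
```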


@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from dataclasses import dataclass, field
-from typing import List, Literal, Optional
+from dataclasses import asdict, dataclass, field
+from typing import Any, Dict, List, Literal, Optional
 
 
 @dataclass
@@ -318,7 +318,7 @@ class SwanLabArguments:
         default=None,
         metadata={"help": "The workspace name in SwanLab."},
     )
-    swanlab_experiment_name: str = field(
+    swanlab_run_name: str = field(
         default=None,
         metadata={"help": "The experiment name in SwanLab."},
     )
@@ -440,3 +440,8 @@ class FinetuningArguments(
         if self.pissa_init:
             raise ValueError("`pissa_init` is only valid for LoRA training.")
+
+    def to_dict(self) -> Dict[str, Any]:
+        args = asdict(self)
+        args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
+        return args
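Here `to_dict` additionally masks any field whose name ends in `api_key`, so secrets are not leaked into the reported config. A minimal sketch of the same masking (fields are illustrative):

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional

@dataclass
class FinetuningArgumentsSketch:
    use_swanlab: bool = True
    swanlab_run_name: Optional[str] = "test_run"
    swanlab_api_key: Optional[str] = "sk-secret"  # never reported verbatim

    def to_dict(self) -> Dict[str, Any]:
        args = asdict(self)
        # replace secret fields with an uppercase placeholder before reporting
        return {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}

print(FinetuningArgumentsSketch().to_dict())
# {'use_swanlab': True, 'swanlab_run_name': 'test_run', 'swanlab_api_key': '<SWANLAB_API_KEY>'}
```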


@@ -16,7 +16,7 @@
 # limitations under the License.
 
 import json
-from dataclasses import dataclass, field, fields
+from dataclasses import asdict, dataclass, field, fields
 from typing import Any, Dict, Literal, Optional, Union
 
 import torch
@@ -344,3 +344,8 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
             setattr(result, name, value)
 
         return result
+
+    def to_dict(self) -> Dict[str, Any]:
+        args = asdict(self)
+        args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
+        return args


@@ -42,10 +42,13 @@ if is_safetensors_available():
     from safetensors import safe_open
     from safetensors.torch import save_file
 
 if TYPE_CHECKING:
     from transformers import TrainerControl, TrainerState, TrainingArguments
     from trl import AutoModelForCausalLMWithValueHead
 
+    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
 
 logger = logging.get_logger(__name__)
@@ -101,9 +104,6 @@ class FixValueHeadModelCallback(TrainerCallback):
     @override
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called after a checkpoint save.
-        """
         if args.should_save:
             output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
             fix_valuehead_checkpoint(
@@ -138,9 +138,6 @@ class PissaConvertCallback(TrainerCallback):
     @override
     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the beginning of training.
-        """
         if args.should_save:
             model = kwargs.pop("model")
             pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
@@ -348,3 +345,51 @@ class LogCallback(TrainerCallback):
                 remaining_time=self.remaining_time,
             )
             self.thread_pool.submit(self._write_log, args.output_dir, logs)
+
+
+class ReporterCallback(TrainerCallback):
+    r"""
+    A callback for reporting training status to external logger.
+    """
+
+    def __init__(
+        self,
+        model_args: "ModelArguments",
+        data_args: "DataArguments",
+        finetuning_args: "FinetuningArguments",
+        generating_args: "GeneratingArguments",
+    ) -> None:
+        self.model_args = model_args
+        self.data_args = data_args
+        self.finetuning_args = finetuning_args
+        self.generating_args = generating_args
+        os.environ["WANDB_PROJECT"] = os.getenv("WANDB_PROJECT", "llamafactory")
+
+    @override
+    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        if not state.is_world_process_zero:
+            return
+
+        if "wandb" in args.report_to:
+            import wandb
+
+            wandb.config.update(
+                {
+                    "model_args": self.model_args.to_dict(),
+                    "data_args": self.data_args.to_dict(),
+                    "finetuning_args": self.finetuning_args.to_dict(),
+                    "generating_args": self.generating_args.to_dict(),
+                }
+            )
+
+        if self.finetuning_args.use_swanlab:
+            import swanlab
+
+            swanlab.config.update(
+                {
+                    "model_args": self.model_args.to_dict(),
+                    "data_args": self.data_args.to_dict(),
+                    "finetuning_args": self.finetuning_args.to_dict(),
+                    "generating_args": self.generating_args.to_dict(),
+                }
+            )
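On the main process, the new callback pushes the four argument groups into the active run's config once training starts. A stand-alone sketch of the equivalent wandb call, assuming `wandb` is installed; "offline" mode avoids needing an API key for a local test, and the values are purely illustrative:

```python
import wandb

run = wandb.init(project="llamafactory", mode="offline")
wandb.config.update(
    {
        "model_args": {"model_name_or_path": "demo-model"},   # illustrative values
        "finetuning_args": {"finetuning_type": "lora"},
    }
)
run.finish()
```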


@@ -30,8 +30,8 @@ from typing_extensions import override
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_equal_to_4_46
-from ..callbacks import PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback
+from ..callbacks import SaveProcessorCallback
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
 
 if TYPE_CHECKING:
@@ -97,18 +97,12 @@ class CustomDPOTrainer(DPOTrainer):
         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))
 
-        if finetuning_args.pissa_convert:
-            self.callback_handler.add_callback(PissaConvertCallback)
-
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
-        if finetuning_args.use_swanlab:
-            self.add_callback(get_swanlab_callback(finetuning_args))
-
     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:


@@ -30,7 +30,7 @@ from typing_extensions import override
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_equal_to_4_46
 from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
 
 if TYPE_CHECKING:
@@ -101,9 +101,6 @@ class CustomKTOTrainer(KTOTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
-        if finetuning_args.use_swanlab:
-            self.add_callback(get_swanlab_callback(finetuning_args))
-
     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:


@@ -40,7 +40,7 @@ from typing_extensions import override
 from ...extras import logging
 from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
 from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
@@ -186,9 +186,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
-        if finetuning_args.use_swanlab:
-            self.add_callback(get_swanlab_callback(finetuning_args))
-
     def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
         r"""
         Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.


@@ -20,8 +20,8 @@ from transformers import Trainer
 from typing_extensions import override
 
 from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
-from ..callbacks import PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
+from ..callbacks import SaveProcessorCallback
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
 if TYPE_CHECKING:
@@ -47,18 +47,12 @@ class CustomTrainer(Trainer):
         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))
 
-        if finetuning_args.pissa_convert:
-            self.add_callback(PissaConvertCallback)
-
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
-        if finetuning_args.use_swanlab:
-            self.add_callback(get_swanlab_callback(finetuning_args))
-
     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:


@@ -26,8 +26,8 @@ from typing_extensions import override
 from ...extras import logging
 from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
-from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
+from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
 if TYPE_CHECKING:
@@ -59,18 +59,12 @@ class PairwiseTrainer(Trainer):
         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))
 
-        if finetuning_args.pissa_convert:
-            self.add_callback(PissaConvertCallback)
-
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
-        if finetuning_args.use_swanlab:
-            self.add_callback(get_swanlab_callback(finetuning_args))
-
     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:


@@ -28,8 +28,8 @@ from typing_extensions import override
 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
-from ..callbacks import PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
+from ..callbacks import SaveProcessorCallback
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
 if TYPE_CHECKING:
@@ -62,18 +62,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))
 
-        if finetuning_args.pissa_convert:
-            self.add_callback(PissaConvertCallback)
-
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
-        if finetuning_args.use_swanlab:
-            self.add_callback(get_swanlab_callback(finetuning_args))
-
     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:


@@ -472,9 +472,8 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
     swanlab_callback = SwanLabCallback(
         project=finetuning_args.swanlab_project,
         workspace=finetuning_args.swanlab_workspace,
-        experiment_name=finetuning_args.swanlab_experiment_name,
+        experiment_name=finetuning_args.swanlab_run_name,
         mode=finetuning_args.swanlab_mode,
-        config={"Framework": "🦙LLaMA Factory"},
+        config={"Framework": "🦙LlamaFactory"},
     )
     return swanlab_callback
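For reference, a hedged sketch of constructing the SwanLab callback directly with the same keyword arguments this function forwards from `FinetuningArguments`; the import path is an assumption and may vary between `swanlab` versions:

```python
from swanlab.integration.transformers import SwanLabCallback  # assumed import path

callback = SwanLabCallback(
    project="llamafactory",
    workspace=None,                      # None falls back to the personal workspace
    experiment_name="test_run",          # now fed from swanlab_run_name
    mode="local",                        # "cloud" or "local"
    config={"Framework": "🦙LlamaFactory"},
)
# trainer = Trainer(..., callbacks=[callback])
```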


@@ -24,13 +24,14 @@ from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..hparams import get_infer_args, get_train_args
 from ..model import load_model, load_tokenizer
-from .callbacks import LogCallback
+from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
 from .dpo import run_dpo
 from .kto import run_kto
 from .ppo import run_ppo
 from .pt import run_pt
 from .rm import run_rm
 from .sft import run_sft
+from .trainer_utils import get_swanlab_callback
 
 if TYPE_CHECKING:
@@ -44,6 +45,14 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback
     callbacks.append(LogCallback())
     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
 
+    if finetuning_args.pissa_convert:
+        callbacks.append(PissaConvertCallback())
+
+    if finetuning_args.use_swanlab:
+        callbacks.append(get_swanlab_callback(finetuning_args))
+
+    callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last
+
     if finetuning_args.stage == "pt":
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
     elif finetuning_args.stage == "sft":


@@ -273,21 +273,23 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Accordion(open=False) as swanlab_tab:
         with gr.Row():
             use_swanlab = gr.Checkbox()
-            swanlab_project = gr.Textbox(value="llamafactory", placeholder="Project name", interactive=True)
-            swanlab_experiment_name = gr.Textbox(value="", placeholder="Experiment name", interactive=True)
-            swanlab_workspace = gr.Textbox(value="", placeholder="Workspace name", interactive=True)
-            swanlab_api_key = gr.Textbox(value="", placeholder="API key", interactive=True)
-            swanlab_mode = gr.Dropdown(choices=["cloud", "local", "disabled"], value="cloud", interactive=True)
+            swanlab_project = gr.Textbox(value="llamafactory")
+            swanlab_run_name = gr.Textbox()
+            swanlab_workspace = gr.Textbox()
+            swanlab_api_key = gr.Textbox()
+            swanlab_mode = gr.Dropdown(choices=["cloud", "local"], value="cloud")
 
-    input_elems.update({use_swanlab, swanlab_api_key, swanlab_project, swanlab_workspace, swanlab_experiment_name, swanlab_mode})
+    input_elems.update(
+        {use_swanlab, swanlab_project, swanlab_run_name, swanlab_workspace, swanlab_api_key, swanlab_mode}
+    )
     elem_dict.update(
         dict(
             swanlab_tab=swanlab_tab,
             use_swanlab=use_swanlab,
-            swanlab_api_key=swanlab_api_key,
             swanlab_project=swanlab_project,
+            swanlab_run_name=swanlab_run_name,
             swanlab_workspace=swanlab_workspace,
-            swanlab_experiment_name=swanlab_experiment_name,
+            swanlab_api_key=swanlab_api_key,
             swanlab_mode=swanlab_mode,
         )
     )


@@ -1119,7 +1119,7 @@ LOCALES = {
             "info": "Нормализация оценок в тренировке PPO.",
         },
         "zh": {
-            "label": "奖励模型",
+            "label": "归一化分数",
             "info": "PPO 训练中归一化奖励分数。",
         },
         "ko": {
@@ -1385,86 +1385,85 @@ LOCALES = {
             "info": "SwanLab를 사용하여 실험을 추적하고 시각화합니다.",
         },
     },
-    "swanlab_api_key": {
-        "en": {
-            "label": "API Key(optional)",
-            "info": "API key for SwanLab. Once logged in, no need to login again in the programming environment.",
-        },
-        "ru": {
-            "label": "API ключ(Необязательный)",
-            "info": "API ключ для SwanLab. После входа в программное окружение, нет необходимости входить снова.",
-        },
-        "zh": {
-            "label": "API密钥(选填)",
-            "info": "用于在编程环境登录SwanLab，已登录则无需填写。",
-        },
-        "ko": {
-            "label": "API 키(선택 사항)",
-            "info": "SwanLab의 API 키. 프로그래밍 환경에 로그인한 후 다시 로그인할 필요가 없습니다.",
-        },
-    },
"swanlab_project": { "swanlab_project": {
"en": { "en": {
"label": "Project(optional)", "label": "SwanLab project",
}, },
"ru": { "ru": {
"label": "Проект(Необязательный)", "label": "SwanLab Проект",
}, },
"zh": { "zh": {
"label": "项目(选填)", "label": "SwanLab 项目名",
}, },
"ko": { "ko": {
"label": "프로젝트(선택 사항)", "label": "SwanLab 프로젝트",
},
},
"swanlab_run_name": {
"en": {
"label": "SwanLab experiment name (optional)",
},
"ru": {
"label": "SwanLab Имя эксперимента (опционально)",
},
"zh": {
"label": "SwanLab 实验名(非必填)",
},
"ko": {
"label": "SwanLab 실험 이름 (선택 사항)",
}, },
}, },
"swanlab_workspace": { "swanlab_workspace": {
"en": { "en": {
"label": "Workspace(optional)", "label": "SwanLab workspace (optional)",
"info": "Workspace for SwanLab. If not filled, it defaults to the personal workspace.", "info": "Workspace for SwanLab. Defaults to the personal workspace.",
}, },
"ru": { "ru": {
"label": "Рабочая область(Необязательный)", "label": "SwanLab Рабочая область (опционально)",
"info": "Рабочая область SwanLab, если не заполнено, то по умолчанию в личной рабочей области.", "info": "Рабочая область SwanLab, если не заполнено, то по умолчанию в личной рабочей области.",
}, },
"zh": { "zh": {
"label": "Workspace(选填)", "label": "SwanLab 工作区(非必填)",
"info": "SwanLab组织的工作区,如不填写则默认在个人工作区下", "info": "SwanLab 的工作区,默认在个人工作区下。",
}, },
"ko": { "ko": {
"label": "작업 영역(선택 사항)", "label": "SwanLab 작업 영역 (선택 사항)",
"info": "SwanLab 조직의 작업 영역, 비어 있으면 기본적으로 개인 작업 영역에 있습니다.", "info": "SwanLab 조직의 작업 영역, 비어 있으면 기본적으로 개인 작업 영역에 있습니다.",
}, },
}, },
"swanlab_experiment_name": { "swanlab_api_key": {
"en": { "en": {
"label": "Experiment name (optional)", "label": "SwanLab API key (optional)",
"info": "API key for SwanLab.",
}, },
"ru": { "ru": {
"label": "Имя эксперимента(Необязательный)", "label": "SwanLab API ключ (опционально)",
"info": "API ключ для SwanLab.",
}, },
"zh": { "zh": {
"label": "实验名(选填) ", "label": "SwanLab API密钥非必填",
"info": "用于在编程环境登录 SwanLab已登录则无需填写。",
}, },
"ko": { "ko": {
"label": "실험 이름(선택 사항)", "label": "SwanLab API 키 (선택 사항)",
"info": "SwanLab의 API 키.",
}, },
}, },
"swanlab_mode": { "swanlab_mode": {
"en": { "en": {
"label": "Mode", "label": "SwanLab mode",
"info": "Cloud or offline version.", "info": "Cloud or offline version.",
}, },
"ru": { "ru": {
"label": "Режим", "label": "SwanLab Режим",
"info": "Версия в облаке или локальная версия.", "info": "Версия в облаке или локальная версия.",
}, },
"zh": { "zh": {
"label": "模式", "label": "SwanLab 模式",
"info": "云端版或离线版", "info": "使用云端版或离线版 SwanLab。",
}, },
"ko": { "ko": {
"label": "모드", "label": "SwanLab 모드",
"info": "클라우드 버전 또는 오프라인 버전.", "info": "클라우드 버전 또는 오프라인 버전.",
}, },
}, },


@@ -231,13 +231,12 @@ class Runner:
         # swanlab config
         if get("train.use_swanlab"):
-            args["swanlab_api_key"] = get("train.swanlab_api_key")
             args["swanlab_project"] = get("train.swanlab_project")
+            args["swanlab_run_name"] = get("train.swanlab_run_name")
             args["swanlab_workspace"] = get("train.swanlab_workspace")
-            args["swanlab_experiment_name"] = get("train.swanlab_experiment_name")
+            args["swanlab_api_key"] = get("train.swanlab_api_key")
             args["swanlab_mode"] = get("train.swanlab_mode")
 
         # eval config
         if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
             args["val_size"] = get("train.val_size")