mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	support report custom args
Former-commit-id: d41254c40a1c5cacf9377096adb27efa9bdb79ea
This commit is contained in:
		
							parent
							
								
									adff887659
								
							
						
					
					
						commit
						a897d46049
					
				
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@ -171,4 +171,5 @@ config/
 | 
			
		||||
saves/
 | 
			
		||||
output/
 | 
			
		||||
wandb/
 | 
			
		||||
swanlog/
 | 
			
		||||
generated_predictions.jsonl
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										14
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								README.md
									
									
									
									
									
								
							@ -4,7 +4,7 @@
 | 
			
		||||
[](LICENSE)
 | 
			
		||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 | 
			
		||||
[](https://pypi.org/project/llamafactory/)
 | 
			
		||||
[](#projects-using-llama-factory)
 | 
			
		||||
[](https://scholar.google.com/scholar?cites=12620864006390196564)
 | 
			
		||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
 | 
			
		||||
[](https://discord.gg/rKfvV9r9FK)
 | 
			
		||||
[](https://twitter.com/llamafactory_ai)
 | 
			
		||||
@ -13,6 +13,7 @@
 | 
			
		||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
 | 
			
		||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
 | 
			
		||||
[](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
 | 
			
		||||
[](https://gitcode.com/zhengyaowei/LLaMA-Factory)
 | 
			
		||||
 | 
			
		||||
[](https://trendshift.io/repositories/4535)
 | 
			
		||||
 | 
			
		||||
@ -87,18 +88,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | 
			
		||||
 | 
			
		||||
## Changelog
 | 
			
		||||
 | 
			
		||||
[24/12/21] We supported **[SwanLab](https://github.com/SwanHubX/SwanLab)** experiment tracking and visualization. See [this section](#use-swanlab-logger) for details.
 | 
			
		||||
[24/12/21] We supported using **[SwanLab](https://github.com/SwanHubX/SwanLab)** for experiment tracking and visualization. See [this section](#use-swanlab-logger) for details.
 | 
			
		||||
 | 
			
		||||
[24/11/27] We supported fine-tuning the **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** model and the **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** dataset.
 | 
			
		||||
 | 
			
		||||
[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.
 | 
			
		||||
 | 
			
		||||
<details><summary>Full Changelog</summary>
 | 
			
		||||
 | 
			
		||||
[24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.
 | 
			
		||||
 | 
			
		||||
[24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
 | 
			
		||||
 | 
			
		||||
<details><summary>Full Changelog</summary>
 | 
			
		||||
 | 
			
		||||
[24/08/27] We supported **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.
 | 
			
		||||
 | 
			
		||||
[24/08/09] We supported **[Adam-mini](https://github.com/zyushun/Adam-mini)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
 | 
			
		||||
@ -388,7 +389,7 @@ cd LLaMA-Factory
 | 
			
		||||
pip install -e ".[torch,metrics]"
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, quality
 | 
			
		||||
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, swanlab, quality
 | 
			
		||||
 | 
			
		||||
> [!TIP]
 | 
			
		||||
> Use `pip install --no-deps -e .` to resolve package conflicts.
 | 
			
		||||
@ -642,8 +643,7 @@ To use [SwanLab](https://github.com/SwanHubX/SwanLab) for logging experimental r
 | 
			
		||||
 | 
			
		||||
```yaml
 | 
			
		||||
use_swanlab: true
 | 
			
		||||
swanlab_project: test_project # optional
 | 
			
		||||
swanlab_experiment_name: test_experiment # optional
 | 
			
		||||
swanlab_run_name: test_run # optional
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
When launching training tasks, you can log in to SwanLab in three ways:
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										15
									
								
								README_zh.md
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								README_zh.md
									
									
									
									
									
								
							@ -4,7 +4,7 @@
 | 
			
		||||
[](LICENSE)
 | 
			
		||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 | 
			
		||||
[](https://pypi.org/project/llamafactory/)
 | 
			
		||||
[](#使用了-llama-factory-的项目)
 | 
			
		||||
[](https://scholar.google.com/scholar?cites=12620864006390196564)
 | 
			
		||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
 | 
			
		||||
[](https://discord.gg/rKfvV9r9FK)
 | 
			
		||||
[](https://twitter.com/llamafactory_ai)
 | 
			
		||||
@ -13,6 +13,7 @@
 | 
			
		||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
 | 
			
		||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
 | 
			
		||||
[](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
 | 
			
		||||
[](https://gitcode.com/zhengyaowei/LLaMA-Factory)
 | 
			
		||||
 | 
			
		||||
[](https://trendshift.io/repositories/4535)
 | 
			
		||||
 | 
			
		||||
@ -88,18 +89,18 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | 
			
		||||
 | 
			
		||||
## 更新日志
 | 
			
		||||
 | 
			
		||||
[24/12/21] 我们支持了 **[SwanLab](https://github.com/SwanHubX/SwanLab)** 跟踪与可视化实验。详细用法请参考 [此部分](#使用-wb-面板)。 
 | 
			
		||||
[24/12/21] 我们支持了使用 **[SwanLab](https://github.com/SwanHubX/SwanLab)** 跟踪与可视化实验。详细用法请参考 [此部分](#使用-swanlab-面板)。
 | 
			
		||||
 | 
			
		||||
[24/11/27] 我们支持了 **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** 模型的微调和 **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** 数据集。
 | 
			
		||||
 | 
			
		||||
[24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。
 | 
			
		||||
 | 
			
		||||
<details><summary>展开日志</summary>
 | 
			
		||||
 | 
			
		||||
[24/09/19] 我们支持了 **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** 模型的微调。
 | 
			
		||||
 | 
			
		||||
[24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。感谢 [@simonJJJ](https://github.com/simonJJJ) 的 PR。
 | 
			
		||||
 | 
			
		||||
<details><summary>展开日志</summary>
 | 
			
		||||
 | 
			
		||||
[24/08/27] 我们支持了 **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**。请使用 `enable_liger_kernel: true` 来加速训练。
 | 
			
		||||
 | 
			
		||||
[24/08/09] 我们支持了 **[Adam-mini](https://github.com/zyushun/Adam-mini)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@relic-yuexi](https://github.com/relic-yuexi) 的 PR。
 | 
			
		||||
@ -389,7 +390,7 @@ cd LLaMA-Factory
 | 
			
		||||
pip install -e ".[torch,metrics]"
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、openmind、quality
 | 
			
		||||
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、openmind、swanlab、quality
 | 
			
		||||
 | 
			
		||||
> [!TIP]
 | 
			
		||||
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
 | 
			
		||||
@ -643,8 +644,7 @@ run_name: test_run # 可选
 | 
			
		||||
 | 
			
		||||
```yaml
 | 
			
		||||
use_swanlab: true
 | 
			
		||||
swanlab_project: test_run # 可选
 | 
			
		||||
swanlab_experiment_name: test_experiment # 可选
 | 
			
		||||
swanlab_run_name: test_run # 可选
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
在启动训练任务时,登录SwanLab账户有以下三种方式:
 | 
			
		||||
@ -653,7 +653,6 @@ swanlab_experiment_name: test_experiment # 可选
 | 
			
		||||
方式二:将环境变量 `SWANLAB_API_KEY` 设置为你的 [API 密钥](https://swanlab.cn/settings)。
 | 
			
		||||
方式三:启动前使用 `swanlab login` 命令完成登录。
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
## 使用了 LLaMA Factory 的项目
 | 
			
		||||
 | 
			
		||||
如果您有项目希望添加至下述列表,请通过邮件联系或者创建一个 PR。
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								setup.py
									
									
									
									
									
								
							@ -61,6 +61,7 @@ extra_require = {
 | 
			
		||||
    "qwen": ["transformers_stream_generator"],
 | 
			
		||||
    "modelscope": ["modelscope"],
 | 
			
		||||
    "openmind": ["openmind"],
 | 
			
		||||
    "swanlab": ["swanlab"],
 | 
			
		||||
    "dev": ["pre-commit", "ruff", "pytest"],
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -171,7 +171,10 @@ class HuggingfaceEngine(BaseEngine):
 | 
			
		||||
            elif not isinstance(value, torch.Tensor):
 | 
			
		||||
                value = torch.tensor(value)
 | 
			
		||||
 | 
			
		||||
            gen_kwargs[key] = value.to(dtype=model.dtype, device=model.device)
 | 
			
		||||
            if torch.is_floating_point(value):
 | 
			
		||||
                value = value.to(model.dtype)
 | 
			
		||||
 | 
			
		||||
            gen_kwargs[key] = value.to(model.device)
 | 
			
		||||
 | 
			
		||||
        return gen_kwargs, prompt_length
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -15,8 +15,8 @@
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from dataclasses import dataclass, field
 | 
			
		||||
from typing import Literal, Optional
 | 
			
		||||
from dataclasses import asdict, dataclass, field
 | 
			
		||||
from typing import Any, Dict, Literal, Optional
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
@ -161,3 +161,6 @@ class DataArguments:
 | 
			
		||||
 | 
			
		||||
        if self.mask_history and self.train_on_prompt:
 | 
			
		||||
            raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
 | 
			
		||||
 | 
			
		||||
    def to_dict(self) -> Dict[str, Any]:
 | 
			
		||||
        return asdict(self)
 | 
			
		||||
 | 
			
		||||
@ -12,8 +12,8 @@
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from dataclasses import dataclass, field
 | 
			
		||||
from typing import List, Literal, Optional
 | 
			
		||||
from dataclasses import asdict, dataclass, field
 | 
			
		||||
from typing import Any, Dict, List, Literal, Optional
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
@ -318,7 +318,7 @@ class SwanLabArguments:
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "The workspace name in SwanLab."},
 | 
			
		||||
    )
 | 
			
		||||
    swanlab_experiment_name: str = field(
 | 
			
		||||
    swanlab_run_name: str = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "The experiment name in SwanLab."},
 | 
			
		||||
    )
 | 
			
		||||
@ -440,3 +440,8 @@ class FinetuningArguments(
 | 
			
		||||
 | 
			
		||||
            if self.pissa_init:
 | 
			
		||||
                raise ValueError("`pissa_init` is only valid for LoRA training.")
 | 
			
		||||
 | 
			
		||||
    def to_dict(self) -> Dict[str, Any]:
 | 
			
		||||
        args = asdict(self)
 | 
			
		||||
        args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
 | 
			
		||||
        return args
 | 
			
		||||
 | 
			
		||||
@ -16,7 +16,7 @@
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import json
 | 
			
		||||
from dataclasses import dataclass, field, fields
 | 
			
		||||
from dataclasses import asdict, dataclass, field, fields
 | 
			
		||||
from typing import Any, Dict, Literal, Optional, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
@ -344,3 +344,8 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
 | 
			
		||||
            setattr(result, name, value)
 | 
			
		||||
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
    def to_dict(self) -> Dict[str, Any]:
 | 
			
		||||
        args = asdict(self)
 | 
			
		||||
        args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
 | 
			
		||||
        return args
 | 
			
		||||
 | 
			
		||||
@ -42,10 +42,13 @@ if is_safetensors_available():
 | 
			
		||||
    from safetensors import safe_open
 | 
			
		||||
    from safetensors.torch import save_file
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from transformers import TrainerControl, TrainerState, TrainingArguments
 | 
			
		||||
    from trl import AutoModelForCausalLMWithValueHead
 | 
			
		||||
 | 
			
		||||
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
@ -101,9 +104,6 @@ class FixValueHeadModelCallback(TrainerCallback):
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
 | 
			
		||||
        r"""
 | 
			
		||||
        Event called after a checkpoint save.
 | 
			
		||||
        """
 | 
			
		||||
        if args.should_save:
 | 
			
		||||
            output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
 | 
			
		||||
            fix_valuehead_checkpoint(
 | 
			
		||||
@ -138,9 +138,6 @@ class PissaConvertCallback(TrainerCallback):
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
 | 
			
		||||
        r"""
 | 
			
		||||
        Event called at the beginning of training.
 | 
			
		||||
        """
 | 
			
		||||
        if args.should_save:
 | 
			
		||||
            model = kwargs.pop("model")
 | 
			
		||||
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
 | 
			
		||||
@ -348,3 +345,51 @@ class LogCallback(TrainerCallback):
 | 
			
		||||
                    remaining_time=self.remaining_time,
 | 
			
		||||
                )
 | 
			
		||||
                self.thread_pool.submit(self._write_log, args.output_dir, logs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReporterCallback(TrainerCallback):
 | 
			
		||||
    r"""
 | 
			
		||||
    A callback for reporting training status to external logger.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        model_args: "ModelArguments",
 | 
			
		||||
        data_args: "DataArguments",
 | 
			
		||||
        finetuning_args: "FinetuningArguments",
 | 
			
		||||
        generating_args: "GeneratingArguments",
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        self.model_args = model_args
 | 
			
		||||
        self.data_args = data_args
 | 
			
		||||
        self.finetuning_args = finetuning_args
 | 
			
		||||
        self.generating_args = generating_args
 | 
			
		||||
        os.environ["WANDB_PROJECT"] = os.getenv("WANDB_PROJECT", "llamafactory")
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
 | 
			
		||||
        if not state.is_world_process_zero:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        if "wandb" in args.report_to:
 | 
			
		||||
            import wandb
 | 
			
		||||
 | 
			
		||||
            wandb.config.update(
 | 
			
		||||
                {
 | 
			
		||||
                    "model_args": self.model_args.to_dict(),
 | 
			
		||||
                    "data_args": self.data_args.to_dict(),
 | 
			
		||||
                    "finetuning_args": self.finetuning_args.to_dict(),
 | 
			
		||||
                    "generating_args": self.generating_args.to_dict(),
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if self.finetuning_args.use_swanlab:
 | 
			
		||||
            import swanlab
 | 
			
		||||
 | 
			
		||||
            swanlab.config.update(
 | 
			
		||||
                {
 | 
			
		||||
                    "model_args": self.model_args.to_dict(),
 | 
			
		||||
                    "data_args": self.data_args.to_dict(),
 | 
			
		||||
                    "finetuning_args": self.finetuning_args.to_dict(),
 | 
			
		||||
                    "generating_args": self.generating_args.to_dict(),
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
@ -30,8 +30,8 @@ from typing_extensions import override
 | 
			
		||||
 | 
			
		||||
from ...extras.constants import IGNORE_INDEX
 | 
			
		||||
from ...extras.packages import is_transformers_version_equal_to_4_46
 | 
			
		||||
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
 | 
			
		||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback
 | 
			
		||||
from ..callbacks import SaveProcessorCallback
 | 
			
		||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
@ -97,18 +97,12 @@ class CustomDPOTrainer(DPOTrainer):
 | 
			
		||||
        if processor is not None:
 | 
			
		||||
            self.add_callback(SaveProcessorCallback(processor))
 | 
			
		||||
 | 
			
		||||
        if finetuning_args.pissa_convert:
 | 
			
		||||
            self.callback_handler.add_callback(PissaConvertCallback)
 | 
			
		||||
 | 
			
		||||
        if finetuning_args.use_badam:
 | 
			
		||||
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 | 
			
		||||
 | 
			
		||||
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
 | 
			
		||||
            self.add_callback(BAdamCallback)
 | 
			
		||||
 | 
			
		||||
        if finetuning_args.use_swanlab:
 | 
			
		||||
            self.add_callback(get_swanlab_callback(finetuning_args))
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def create_optimizer(self) -> "torch.optim.Optimizer":
 | 
			
		||||
        if self.optimizer is None:
 | 
			
		||||
 | 
			
		||||
@ -30,7 +30,7 @@ from typing_extensions import override
 | 
			
		||||
from ...extras.constants import IGNORE_INDEX
 | 
			
		||||
from ...extras.packages import is_transformers_version_equal_to_4_46
 | 
			
		||||
from ..callbacks import SaveProcessorCallback
 | 
			
		||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback
 | 
			
		||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
@ -101,9 +101,6 @@ class CustomKTOTrainer(KTOTrainer):
 | 
			
		||||
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
 | 
			
		||||
            self.add_callback(BAdamCallback)
 | 
			
		||||
 | 
			
		||||
        if finetuning_args.use_swanlab:
 | 
			
		||||
            self.add_callback(get_swanlab_callback(finetuning_args))
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def create_optimizer(self) -> "torch.optim.Optimizer":
 | 
			
		||||
        if self.optimizer is None:
 | 
			
		||||
 | 
			
		||||
@ -40,7 +40,7 @@ from typing_extensions import override
 | 
			
		||||
from ...extras import logging
 | 
			
		||||
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
 | 
			
		||||
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
 | 
			
		||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
 | 
			
		||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 | 
			
		||||
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -186,9 +186,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
 | 
			
		||||
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
 | 
			
		||||
            self.add_callback(BAdamCallback)
 | 
			
		||||
 | 
			
		||||
        if finetuning_args.use_swanlab:
 | 
			
		||||
            self.add_callback(get_swanlab_callback(finetuning_args))
 | 
			
		||||
 | 
			
		||||
    def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
 | 
			
		||||
        r"""
 | 
			
		||||
        Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
 | 
			
		||||
 | 
			
		||||
@ -20,8 +20,8 @@ from transformers import Trainer
 | 
			
		||||
from typing_extensions import override
 | 
			
		||||
 | 
			
		||||
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 | 
			
		||||
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
 | 
			
		||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
 | 
			
		||||
from ..callbacks import SaveProcessorCallback
 | 
			
		||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
@ -47,18 +47,12 @@ class CustomTrainer(Trainer):
 | 
			
		||||
        if processor is not None:
 | 
			
		||||
            self.add_callback(SaveProcessorCallback(processor))
 | 
			
		||||
 | 
			
		||||
        if finetuning_args.pissa_convert:
 | 
			
		||||
            self.add_callback(PissaConvertCallback)
 | 
			
		||||
 | 
			
		||||
        if finetuning_args.use_badam:
 | 
			
		||||
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 | 
			
		||||
 | 
			
		||||
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
 | 
			
		||||
            self.add_callback(BAdamCallback)
 | 
			
		||||
 | 
			
		||||
        if finetuning_args.use_swanlab:
 | 
			
		||||
            self.add_callback(get_swanlab_callback(finetuning_args))
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def create_optimizer(self) -> "torch.optim.Optimizer":
 | 
			
		||||
        if self.optimizer is None:
 | 
			
		||||
 | 
			
		||||
@ -26,8 +26,8 @@ from typing_extensions import override
 | 
			
		||||
 | 
			
		||||
from ...extras import logging
 | 
			
		||||
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 | 
			
		||||
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
 | 
			
		||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
 | 
			
		||||
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
 | 
			
		||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
@ -59,18 +59,12 @@ class PairwiseTrainer(Trainer):
 | 
			
		||||
        if processor is not None:
 | 
			
		||||
            self.add_callback(SaveProcessorCallback(processor))
 | 
			
		||||
 | 
			
		||||
        if finetuning_args.pissa_convert:
 | 
			
		||||
            self.add_callback(PissaConvertCallback)
 | 
			
		||||
 | 
			
		||||
        if finetuning_args.use_badam:
 | 
			
		||||
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 | 
			
		||||
 | 
			
		||||
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
 | 
			
		||||
            self.add_callback(BAdamCallback)
 | 
			
		||||
 | 
			
		||||
        if finetuning_args.use_swanlab:
 | 
			
		||||
            self.add_callback(get_swanlab_callback(finetuning_args))
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def create_optimizer(self) -> "torch.optim.Optimizer":
 | 
			
		||||
        if self.optimizer is None:
 | 
			
		||||
 | 
			
		||||
@ -28,8 +28,8 @@ from typing_extensions import override
 | 
			
		||||
from ...extras import logging
 | 
			
		||||
from ...extras.constants import IGNORE_INDEX
 | 
			
		||||
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 | 
			
		||||
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
 | 
			
		||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
 | 
			
		||||
from ..callbacks import SaveProcessorCallback
 | 
			
		||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
@ -62,18 +62,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
 | 
			
		||||
        if processor is not None:
 | 
			
		||||
            self.add_callback(SaveProcessorCallback(processor))
 | 
			
		||||
 | 
			
		||||
        if finetuning_args.pissa_convert:
 | 
			
		||||
            self.add_callback(PissaConvertCallback)
 | 
			
		||||
 | 
			
		||||
        if finetuning_args.use_badam:
 | 
			
		||||
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 | 
			
		||||
 | 
			
		||||
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
 | 
			
		||||
            self.add_callback(BAdamCallback)
 | 
			
		||||
 | 
			
		||||
        if finetuning_args.use_swanlab:
 | 
			
		||||
            self.add_callback(get_swanlab_callback(finetuning_args))
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def create_optimizer(self) -> "torch.optim.Optimizer":
 | 
			
		||||
        if self.optimizer is None:
 | 
			
		||||
 | 
			
		||||
@ -472,9 +472,8 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
 | 
			
		||||
    swanlab_callback = SwanLabCallback(
 | 
			
		||||
        project=finetuning_args.swanlab_project,
 | 
			
		||||
        workspace=finetuning_args.swanlab_workspace,
 | 
			
		||||
        experiment_name=finetuning_args.swanlab_experiment_name,
 | 
			
		||||
        experiment_name=finetuning_args.swanlab_run_name,
 | 
			
		||||
        mode=finetuning_args.swanlab_mode,
 | 
			
		||||
        config={"Framework": "🦙LLaMA Factory"},
 | 
			
		||||
        config={"Framework": "🦙LlamaFactory"},
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    return swanlab_callback
 | 
			
		||||
    return swanlab_callback
 | 
			
		||||
 | 
			
		||||
@ -24,13 +24,14 @@ from ..extras import logging
 | 
			
		||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 | 
			
		||||
from ..hparams import get_infer_args, get_train_args
 | 
			
		||||
from ..model import load_model, load_tokenizer
 | 
			
		||||
from .callbacks import LogCallback
 | 
			
		||||
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
 | 
			
		||||
from .dpo import run_dpo
 | 
			
		||||
from .kto import run_kto
 | 
			
		||||
from .ppo import run_ppo
 | 
			
		||||
from .pt import run_pt
 | 
			
		||||
from .rm import run_rm
 | 
			
		||||
from .sft import run_sft
 | 
			
		||||
from .trainer_utils import get_swanlab_callback
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
@ -44,6 +45,14 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
 | 
			
		||||
    callbacks.append(LogCallback())
 | 
			
		||||
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
 | 
			
		||||
 | 
			
		||||
    if finetuning_args.pissa_convert:
 | 
			
		||||
        callbacks.append(PissaConvertCallback())
 | 
			
		||||
 | 
			
		||||
    if finetuning_args.use_swanlab:
 | 
			
		||||
        callbacks.append(get_swanlab_callback(finetuning_args))
 | 
			
		||||
 | 
			
		||||
    callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last
 | 
			
		||||
 | 
			
		||||
    if finetuning_args.stage == "pt":
 | 
			
		||||
        run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
 | 
			
		||||
    elif finetuning_args.stage == "sft":
 | 
			
		||||
 | 
			
		||||
@ -273,21 +273,23 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
 | 
			
		||||
    with gr.Accordion(open=False) as swanlab_tab:
 | 
			
		||||
        with gr.Row():
 | 
			
		||||
            use_swanlab = gr.Checkbox()
 | 
			
		||||
            swanlab_project = gr.Textbox(value="llamafactory", placeholder="Project name", interactive=True)
 | 
			
		||||
            swanlab_experiment_name = gr.Textbox(value="", placeholder="Experiment name", interactive=True)
 | 
			
		||||
            swanlab_workspace = gr.Textbox(value="", placeholder="Workspace name", interactive=True)
 | 
			
		||||
            swanlab_api_key = gr.Textbox(value="", placeholder="API key", interactive=True)
 | 
			
		||||
            swanlab_mode = gr.Dropdown(choices=["cloud", "local", "disabled"], value="cloud", interactive=True)
 | 
			
		||||
            swanlab_project = gr.Textbox(value="llamafactory")
 | 
			
		||||
            swanlab_run_name = gr.Textbox()
 | 
			
		||||
            swanlab_workspace = gr.Textbox()
 | 
			
		||||
            swanlab_api_key = gr.Textbox()
 | 
			
		||||
            swanlab_mode = gr.Dropdown(choices=["cloud", "local"], value="cloud")
 | 
			
		||||
 | 
			
		||||
    input_elems.update({use_swanlab, swanlab_api_key, swanlab_project, swanlab_workspace, swanlab_experiment_name, swanlab_mode})
 | 
			
		||||
    input_elems.update(
 | 
			
		||||
        {use_swanlab, swanlab_project, swanlab_run_name, swanlab_workspace, swanlab_api_key, swanlab_mode}
 | 
			
		||||
    )
 | 
			
		||||
    elem_dict.update(
 | 
			
		||||
        dict(
 | 
			
		||||
            swanlab_tab=swanlab_tab,
 | 
			
		||||
            use_swanlab=use_swanlab,
 | 
			
		||||
            swanlab_api_key=swanlab_api_key,
 | 
			
		||||
            swanlab_project=swanlab_project,
 | 
			
		||||
            swanlab_run_name=swanlab_run_name,
 | 
			
		||||
            swanlab_workspace=swanlab_workspace,
 | 
			
		||||
            swanlab_experiment_name=swanlab_experiment_name,
 | 
			
		||||
            swanlab_api_key=swanlab_api_key,
 | 
			
		||||
            swanlab_mode=swanlab_mode,
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -1385,86 +1385,85 @@ LOCALES = {
 | 
			
		||||
            "info": "SwanLab를 사용하여 실험을 추적하고 시각화합니다.",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "swanlab_api_key": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "API Key(optional)",
 | 
			
		||||
            "info": "API key for SwanLab. Once logged in, no need to login again in the programming environment.",
 | 
			
		||||
        },
 | 
			
		||||
        "ru": {
 | 
			
		||||
            "label": "API ключ(Необязательный)",
 | 
			
		||||
            "info": "API ключ для SwanLab. После входа в программное окружение, нет необходимости входить снова.",
 | 
			
		||||
        },
 | 
			
		||||
        "zh": {
 | 
			
		||||
            "label": "API密钥(选填)",
 | 
			
		||||
            "info": "用于在编程环境登录SwanLab,已登录则无需填写。",
 | 
			
		||||
        },
 | 
			
		||||
        "ko": {
 | 
			
		||||
            "label": "API 키(선택 사항)",
 | 
			
		||||
            "info": "SwanLab의 API 키. 프로그래밍 환경에 로그인한 후 다시 로그인할 필요가 없습니다.",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "swanlab_project": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "Project(optional)",
 | 
			
		||||
            "label": "SwanLab project",
 | 
			
		||||
        },
 | 
			
		||||
        "ru": {
 | 
			
		||||
            "label": "Проект(Необязательный)",
 | 
			
		||||
            "label": "SwanLab Проект",
 | 
			
		||||
        },
 | 
			
		||||
        "zh": {
 | 
			
		||||
            "label": "项目(选填)",
 | 
			
		||||
            "label": "SwanLab 项目名",
 | 
			
		||||
        },
 | 
			
		||||
        "ko": {
 | 
			
		||||
            "label": "프로젝트(선택 사항)",
 | 
			
		||||
            "label": "SwanLab 프로젝트",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "swanlab_run_name": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "SwanLab experiment name (optional)",
 | 
			
		||||
        },
 | 
			
		||||
        "ru": {
 | 
			
		||||
            "label": "SwanLab Имя эксперимента (опционально)",
 | 
			
		||||
        },
 | 
			
		||||
        "zh": {
 | 
			
		||||
            "label": "SwanLab 实验名(非必填)",
 | 
			
		||||
        },
 | 
			
		||||
        "ko": {
 | 
			
		||||
            "label": "SwanLab 실험 이름 (선택 사항)",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "swanlab_workspace": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "Workspace(optional)",
 | 
			
		||||
            "info": "Workspace for SwanLab. If not filled, it defaults to the personal workspace.",
 | 
			
		||||
            
 | 
			
		||||
            "label": "SwanLab workspace (optional)",
 | 
			
		||||
            "info": "Workspace for SwanLab. Defaults to the personal workspace.",
 | 
			
		||||
        },
 | 
			
		||||
        "ru": {
 | 
			
		||||
            "label": "Рабочая область(Необязательный)",
 | 
			
		||||
            "label": "SwanLab Рабочая область (опционально)",
 | 
			
		||||
            "info": "Рабочая область SwanLab, если не заполнено, то по умолчанию в личной рабочей области.",
 | 
			
		||||
        },
 | 
			
		||||
        "zh": {
 | 
			
		||||
            "label": "Workspace(选填)",
 | 
			
		||||
            "info": "SwanLab组织的工作区,如不填写则默认在个人工作区下",
 | 
			
		||||
            "label": "SwanLab 工作区(非必填)",
 | 
			
		||||
            "info": "SwanLab 的工作区,默认在个人工作区下。",
 | 
			
		||||
        },
 | 
			
		||||
        "ko": {
 | 
			
		||||
            "label": "작업 영역(선택 사항)",
 | 
			
		||||
            "label": "SwanLab 작업 영역 (선택 사항)",
 | 
			
		||||
            "info": "SwanLab 조직의 작업 영역, 비어 있으면 기본적으로 개인 작업 영역에 있습니다.",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "swanlab_experiment_name": {
 | 
			
		||||
    "swanlab_api_key": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "Experiment name (optional)",
 | 
			
		||||
            "label": "SwanLab API key (optional)",
 | 
			
		||||
            "info": "API key for SwanLab.",
 | 
			
		||||
        },
 | 
			
		||||
        "ru": {
 | 
			
		||||
            "label": "Имя эксперимента(Необязательный)",
 | 
			
		||||
            "label": "SwanLab API ключ (опционально)",
 | 
			
		||||
            "info": "API ключ для SwanLab.",
 | 
			
		||||
        },
 | 
			
		||||
        "zh": {
 | 
			
		||||
            "label": "实验名(选填)  ",
 | 
			
		||||
            "label": "SwanLab API密钥(非必填)",
 | 
			
		||||
            "info": "用于在编程环境登录 SwanLab,已登录则无需填写。",
 | 
			
		||||
        },
 | 
			
		||||
        "ko": {
 | 
			
		||||
            "label": "실험 이름(선택 사항)",
 | 
			
		||||
            "label": "SwanLab API 키 (선택 사항)",
 | 
			
		||||
            "info": "SwanLab의 API 키.",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "swanlab_mode": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "Mode",
 | 
			
		||||
            "info": "Cloud or offline version.",    
 | 
			
		||||
            "label": "SwanLab mode",
 | 
			
		||||
            "info": "Cloud or offline version.",
 | 
			
		||||
        },
 | 
			
		||||
        "ru": {
 | 
			
		||||
            "label": "Режим",
 | 
			
		||||
            "label": "SwanLab Режим",
 | 
			
		||||
            "info": "Версия в облаке или локальная версия.",
 | 
			
		||||
        },
 | 
			
		||||
        "zh": {
 | 
			
		||||
            "label": "模式",
 | 
			
		||||
            "info": "云端版或离线版",
 | 
			
		||||
            "label": "SwanLab 模式",
 | 
			
		||||
            "info": "使用云端版或离线版 SwanLab。",
 | 
			
		||||
        },
 | 
			
		||||
        "ko": {
 | 
			
		||||
            "label": "모드",
 | 
			
		||||
            "label": "SwanLab 모드",
 | 
			
		||||
            "info": "클라우드 버전 또는 오프라인 버전.",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
 | 
			
		||||
@ -231,12 +231,11 @@ class Runner:
 | 
			
		||||
 | 
			
		||||
        # swanlab config
 | 
			
		||||
        if get("train.use_swanlab"):
 | 
			
		||||
            args["swanlab_api_key"] = get("train.swanlab_api_key")
 | 
			
		||||
            args["swanlab_project"] = get("train.swanlab_project")
 | 
			
		||||
            args["swanlab_run_name"] = get("train.swanlab_run_name")
 | 
			
		||||
            args["swanlab_workspace"] = get("train.swanlab_workspace")
 | 
			
		||||
            args["swanlab_experiment_name"] = get("train.swanlab_experiment_name")
 | 
			
		||||
            args["swanlab_api_key"] = get("train.swanlab_api_key")
 | 
			
		||||
            args["swanlab_mode"] = get("train.swanlab_mode")
 | 
			
		||||
            
 | 
			
		||||
 | 
			
		||||
        # eval config
 | 
			
		||||
        if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user