Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-18 04:40:35 +08:00)

Commit: support report custom args
@@ -42,10 +42,13 @@ if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.torch import save_file


if TYPE_CHECKING:
    from transformers import TrainerControl, TrainerState, TrainingArguments
    from trl import AutoModelForCausalLMWithValueHead

    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = logging.get_logger(__name__)

@@ -101,9 +104,6 @@ class FixValueHeadModelCallback(TrainerCallback):

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a checkpoint save.
        """
        if args.should_save:
            output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
            fix_valuehead_checkpoint(
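
Note: PREFIX_CHECKPOINT_DIR is the stock transformers constant (the literal string "checkpoint"), so the directory handed to fix_valuehead_checkpoint above is the usual checkpoint-<step> folder. A tiny sketch with invented values:

import os

from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR  # the literal string "checkpoint"

output_dir = "saves/demo"  # illustrative value
global_step = 500          # illustrative value
print(os.path.join(output_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}"))  # -> saves/demo/checkpoint-500
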
@@ -138,9 +138,6 @@ class PissaConvertCallback(TrainerCallback):

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the beginning of training.
        """
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
@@ -348,3 +345,51 @@ class LogCallback(TrainerCallback):
            remaining_time=self.remaining_time,
        )
        self.thread_pool.submit(self._write_log, args.output_dir, logs)


class ReporterCallback(TrainerCallback):
    r"""
    A callback for reporting training status to external logger.
    """

    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.model_args = model_args
        self.data_args = data_args
        self.finetuning_args = finetuning_args
        self.generating_args = generating_args
        os.environ["WANDB_PROJECT"] = os.getenv("WANDB_PROJECT", "llamafactory")

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not state.is_world_process_zero:
            return

        if "wandb" in args.report_to:
            import wandb

            wandb.config.update(
                {
                    "model_args": self.model_args.to_dict(),
                    "data_args": self.data_args.to_dict(),
                    "finetuning_args": self.finetuning_args.to_dict(),
                    "generating_args": self.generating_args.to_dict(),
                }
            )

        if self.finetuning_args.use_swanlab:
            import swanlab

            swanlab.config.update(
                {
                    "model_args": self.model_args.to_dict(),
                    "data_args": self.data_args.to_dict(),
                    "finetuning_args": self.finetuning_args.to_dict(),
                    "generating_args": self.generating_args.to_dict(),
                }
            )

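Note: the new ReporterCallback only mirrors the four argument groups into whichever tracker is active (wandb and/or SwanLab). Below is a self-contained sketch of the same idea that writes them to a file instead; the CustomArgsReporter name and the run_args.json filename are illustrative, not part of this commit:

import json
import os
from dataclasses import asdict, is_dataclass

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments


class CustomArgsReporter(TrainerCallback):
    """Dump user-supplied argument groups to a JSON file when training starts."""

    def __init__(self, **arg_groups) -> None:
        # arg_groups: e.g. model_args=..., data_args=... (dataclass instances or plain dicts)
        self.arg_groups = arg_groups

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if not state.is_world_process_zero:  # report from the main process only
            return

        payload = {
            name: asdict(group) if is_dataclass(group) else dict(group)
            for name, group in self.arg_groups.items()
        }
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "run_args.json"), "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, default=str)

Usage would be trainer.add_callback(CustomArgsReporter(model_args=model_args, data_args=data_args)), matching how ReporterCallback is appended in run_exp further down.
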
@@ -30,8 +30,8 @@ from typing_extensions import override

from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps


if TYPE_CHECKING:
@@ -97,18 +97,12 @@ class CustomDPOTrainer(DPOTrainer):
        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.pissa_convert:
            self.callback_handler.add_callback(PissaConvertCallback)

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

        if finetuning_args.use_swanlab:
            self.add_callback(get_swanlab_callback(finetuning_args))

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:

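Note: the registrations above rely on Trainer.add_callback accepting either a callback class or an instance, as the BAdamCallback and SaveProcessorCallback(processor) lines show. A minimal demonstration, assuming transformers, torch and accelerate are installed; the toy model, output directory and PrintOnSave callback are invented for illustration:

import torch
from transformers import Trainer, TrainerCallback, TrainingArguments


class PrintOnSave(TrainerCallback):
    # Toy callback: announce every checkpoint save.
    def on_save(self, args, state, control, **kwargs):
        print(f"checkpoint saved at step {state.global_step}")


# Assumption for illustration only: any nn.Module is enough to construct a Trainer object here.
trainer = Trainer(model=torch.nn.Linear(4, 4), args=TrainingArguments(output_dir="tmp_trainer"))
trainer.add_callback(PrintOnSave)    # passing the class works (the handler instantiates it) ...
trainer.add_callback(PrintOnSave())  # ... and so does an instance, mirroring the calls above.
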
@@ -30,7 +30,7 @@ from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps


if TYPE_CHECKING:
@@ -101,9 +101,6 @@ class CustomKTOTrainer(KTOTrainer):
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

        if finetuning_args.use_swanlab:
            self.add_callback(get_swanlab_callback(finetuning_args))

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:

@@ -40,7 +40,7 @@ from typing_extensions import override
from ...extras import logging
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm


@@ -186,9 +186,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

        if finetuning_args.use_swanlab:
            self.add_callback(get_swanlab_callback(finetuning_args))

    def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
        r"""
        Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.

@@ -20,8 +20,8 @@ from transformers import Trainer
from typing_extensions import override

from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
@@ -47,18 +47,12 @@ class CustomTrainer(Trainer):
        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.pissa_convert:
            self.add_callback(PissaConvertCallback)

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

        if finetuning_args.use_swanlab:
            self.add_callback(get_swanlab_callback(finetuning_args))

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:

@@ -26,8 +26,8 @@ from typing_extensions import override

from ...extras import logging
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
@@ -59,18 +59,12 @@ class PairwiseTrainer(Trainer):
        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.pissa_convert:
            self.add_callback(PissaConvertCallback)

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

        if finetuning_args.use_swanlab:
            self.add_callback(get_swanlab_callback(finetuning_args))

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:

@@ -28,8 +28,8 @@ from typing_extensions import override
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
@@ -62,18 +62,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.pissa_convert:
            self.add_callback(PissaConvertCallback)

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

        if finetuning_args.use_swanlab:
            self.add_callback(get_swanlab_callback(finetuning_args))

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:

@@ -472,9 +472,8 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
    swanlab_callback = SwanLabCallback(
        project=finetuning_args.swanlab_project,
        workspace=finetuning_args.swanlab_workspace,
        experiment_name=finetuning_args.swanlab_experiment_name,
        experiment_name=finetuning_args.swanlab_run_name,
        mode=finetuning_args.swanlab_mode,
        config={"Framework": "🦙LLaMA Factory"},
        config={"Framework": "🦙LlamaFactory"},
    )

    return swanlab_callback

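Note: ReporterCallback and get_swanlab_callback both push run-level metadata into the tracker's config. The same kind of entry can also be set directly through wandb's public API; a standalone sketch with illustrative names and values, unrelated to the SwanLab path in this hunk:

import wandb

# Open an offline run and attach custom argument groups to its config, which is what
# ReporterCallback does from inside the Trainer once training begins.
run = wandb.init(project="llamafactory", name="config-demo", mode="offline")  # illustrative names
run.config.update(
    {
        "finetuning_args": {"finetuning_type": "lora", "lora_rank": 8},  # illustrative values
        "generating_args": {"temperature": 0.7, "top_p": 0.9},
    }
)
run.finish()
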
@@ -24,13 +24,14 @@ from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .dpo import run_dpo
from .kto import run_kto
from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
from .trainer_utils import get_swanlab_callback


if TYPE_CHECKING:
@@ -44,6 +45,14 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
    callbacks.append(LogCallback())
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)

    if finetuning_args.pissa_convert:
        callbacks.append(PissaConvertCallback())

    if finetuning_args.use_swanlab:
        callbacks.append(get_swanlab_callback(finetuning_args))

    callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last

    if finetuning_args.stage == "pt":
        run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "sft":

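Note: with this commit, run_exp appends PissaConvertCallback, the SwanLab callback, and ReporterCallback once, instead of every trainer wiring them up itself. A hedged usage sketch of the entry point; the values are placeholders and the exact set of required keys depends on the installed version:

from llamafactory.train.tuner import run_exp

run_exp(
    {
        "stage": "sft",
        "do_train": True,
        "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
        "dataset": "alpaca_en_demo",  # placeholder dataset
        "template": "llama3",
        "finetuning_type": "lora",
        "output_dir": "saves/llama3-8b/lora/sft",
        "report_to": "wandb",  # picked up by ReporterCallback's wandb branch
        "use_swanlab": False,  # set True to also register the SwanLab callback
        "pissa_convert": False,  # set True to register PissaConvertCallback
    }
)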