mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 03:40:34 +08:00
use fp16 model, add logcallback
This commit is contained in:
@@ -1,8 +1,18 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import torch
|
||||
from typing import Dict, Optional
|
||||
from datetime import timedelta
|
||||
|
||||
from transformers import (
|
||||
Seq2SeqTrainer,
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments
|
||||
)
|
||||
|
||||
from transformers import Seq2SeqTrainer
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.modeling_utils import unwrap_model
|
||||
|
||||
@@ -23,6 +33,44 @@ from .other import (
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
r"""
|
||||
TrainerCallback includes the state function during training, for more details refer to the TrainerCallback class.
|
||||
The on_log function primarily collects process parameters during training, such as training loss, learning rate,
|
||||
and training epochs, as well as progress parameters like the current percentage progress and estimated remaining
|
||||
time. Every time a log is triggered, a new record is appended to the file "messages.log" for dynamic visualization
|
||||
purposes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.start_time = time.time()
|
||||
|
||||
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
|
||||
r"""
|
||||
Event called after logging the last logs.
|
||||
"""
|
||||
cur_time = time.time()
|
||||
cur_steps = state.log_history[-1].get("step")
|
||||
elapsed_time = cur_time - self.start_time
|
||||
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
|
||||
remaining_steps = state.max_steps - cur_steps
|
||||
remaining_time = remaining_steps * avg_time_per_step
|
||||
log_dict = {
|
||||
"current_steps": cur_steps,
|
||||
"total_steps": state.max_steps,
|
||||
"loss": state.log_history[-1].get("loss", None),
|
||||
"reward": state.log_history[-1].get("reward", None),
|
||||
"learning_rate": state.log_history[-1].get("learning_rate", None),
|
||||
"epoch": state.log_history[-1].get("epoch", None),
|
||||
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
|
||||
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
|
||||
"remaining_time": str(timedelta(seconds=int(remaining_time)))
|
||||
}
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a") as f:
|
||||
f.write(json.dumps(log_dict) + "\n")
|
||||
|
||||
|
||||
class PeftTrainer(Seq2SeqTrainer):
|
||||
r"""
|
||||
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
|
||||
@@ -31,6 +79,9 @@ class PeftTrainer(Seq2SeqTrainer):
|
||||
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
if os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
|
||||
logger.warning("Previous log file in this folder will be deleted.")
|
||||
os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
|
||||
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
|
||||
r"""
|
||||
|
||||
Reference in New Issue
Block a user