use fp16 model, add logcallback

hiyouga
2023-05-28 21:30:28 +08:00
parent 769c6ab56b
commit 0c9fda01e3
7 changed files with 112 additions and 10 deletions
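
Not shown in this hunk is how the new log file is consumed downstream; the callback itself is presumably attached through the callbacks argument that Seq2SeqTrainer already accepts, e.g. callbacks=[LogCallback()], in one of the other changed files. As a minimal, self-contained sketch of reading the log back (the helper name read_trainer_log and the output/ path are assumptions for illustration, not part of this commit), each line of trainer_log.jsonl is one JSON record appended by LogCallback.on_log:

import json

def read_trainer_log(path: str) -> list:
    # Parse the JSON-lines file appended by LogCallback.on_log; one record per log event.
    records = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records

# Example: print the latest progress snapshot written during training.
logs = read_trainer_log("output/trainer_log.jsonl")
if logs:
    last = logs[-1]
    print(f'{last["percentage"]}% done, ETA {last["remaining_time"]}, loss {last["loss"]}')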


@@ -1,8 +1,18 @@
import os
import json
import time
import torch
from typing import Dict, Optional
from datetime import timedelta
from transformers import (
    Seq2SeqTrainer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments
)
from transformers import Seq2SeqTrainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.modeling_utils import unwrap_model
@@ -23,6 +33,44 @@ from .other import (
logger = get_logger(__name__)


class LogCallback(TrainerCallback):
r"""
TrainerCallback includes the state function during training, for more details refer to the TrainerCallback class.
The on_log function primarily collects process parameters during training, such as training loss, learning rate,
and training epochs, as well as progress parameters like the current percentage progress and estimated remaining
time. Every time a log is triggered, a new record is appended to the file "messages.log" for dynamic visualization
purposes.
"""

    def __init__(self):
        self.start_time = time.time()

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
        r"""
        Event called after logging the last logs.
        """
        cur_time = time.time()
        cur_steps = state.log_history[-1].get("step")
        elapsed_time = cur_time - self.start_time
        # Estimate the remaining time from the average wall-clock time per completed step
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_steps = state.max_steps - cur_steps
        remaining_time = remaining_steps * avg_time_per_step
        log_dict = {
            "current_steps": cur_steps,
            "total_steps": state.max_steps,
            "loss": state.log_history[-1].get("loss", None),
            "reward": state.log_history[-1].get("reward", None),
            "learning_rate": state.log_history[-1].get("learning_rate", None),
            "epoch": state.log_history[-1].get("epoch", None),
            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
            "remaining_time": str(timedelta(seconds=int(remaining_time)))
        }
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a") as f:
            f.write(json.dumps(log_dict) + "\n")


class PeftTrainer(Seq2SeqTrainer):
    r"""
    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
@@ -31,6 +79,9 @@ class PeftTrainer(Seq2SeqTrainer):
    def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args
        if os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
            logger.warning("Previous log file in this folder will be deleted.")
            os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
        r"""