use fp16 model, add logcallback

Former-commit-id: bea275d51338b49ce855eec0178e759607265e3d
This commit is contained in:
hiyouga
2023-05-28 21:30:28 +08:00
parent 54574f1dfa
commit a4384e442c
7 changed files with 112 additions and 10 deletions

View File

@@ -1,8 +1,18 @@
import os
import json
import time
import torch
from typing import Dict, Optional
from datetime import timedelta
from transformers import (
Seq2SeqTrainer,
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments
)
from transformers import Seq2SeqTrainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.modeling_utils import unwrap_model
@@ -23,6 +33,44 @@ from .other import (
logger = get_logger(__name__)
class LogCallback(TrainerCallback):
    r"""
    TrainerCallback that records training progress for dynamic visualization purposes.

    Every time a log is triggered, a new JSON record is appended to the file "trainer_log.jsonl" inside
    `args.output_dir`. Each record holds process parameters taken from the latest log entry (training loss,
    reward, learning rate and epoch, `None` when absent) as well as progress parameters (current and total
    steps, percentage progress, elapsed time and estimated remaining time).
    """

    def __init__(self):
        # Reference point for the elapsed / remaining time estimates.
        self.start_time = time.time()

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
        r"""
        Event called after logging the last logs.

        Appends one JSON line describing the current progress to "trainer_log.jsonl" in `args.output_dir`.
        """
        # Only the main process writes, otherwise every rank would append a duplicate record.
        if not state.is_world_process_zero:
            return

        # Nothing has been logged yet, so there is nothing to report.
        if not state.log_history:
            return

        cur_time = time.time()
        last_log = state.log_history[-1]

        # Fall back to the trainer's global step when the log entry carries no "step" field.
        cur_steps = last_log.get("step")
        if cur_steps is None:
            cur_steps = state.global_step

        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        # Clamp at zero so the remaining time never becomes negative.
        remaining_steps = max(state.max_steps - cur_steps, 0)
        remaining_time = remaining_steps * avg_time_per_step

        log_dict = {
            "current_steps": cur_steps,
            "total_steps": state.max_steps,
            "loss": last_log.get("loss", None),
            "reward": last_log.get("reward", None),
            "learning_rate": last_log.get("learning_rate", None),
            "epoch": last_log.get("epoch", None),
            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
            "remaining_time": str(timedelta(seconds=int(remaining_time)))
        }

        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
            f.write(json.dumps(log_dict) + "\n")
class PeftTrainer(Seq2SeqTrainer):
r"""
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
@@ -31,6 +79,9 @@ class PeftTrainer(Seq2SeqTrainer):
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
    r"""
    Initializes the underlying trainer and removes a stale progress log from the output directory.

    Args:
        finetuning_args: fine-tuning configuration, stored on the trainer as `self.finetuning_args`.
        **kwargs: forwarded unchanged to the parent trainer constructor.
    """
    super().__init__(**kwargs)
    self.finetuning_args = finetuning_args

    log_path = os.path.join(self.args.output_dir, "trainer_log.jsonl")
    # Only the main process cleans up; otherwise all ranks would race to delete the same file.
    if self.is_world_process_zero() and os.path.exists(log_path):
        logger.warning("Previous log file in this folder will be deleted.")
        try:
            os.remove(log_path)
        except FileNotFoundError:
            # The file vanished between the check and the removal; nothing left to do.
            pass
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
r"""