Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-18 12:50:38 +08:00)
Fix DPO metrics
This commit is contained in:
@@ -287,13 +287,13 @@ class LogCallback(TrainerCallback):
|
||||
logs = dict(
|
||||
current_steps=self.cur_steps,
|
||||
total_steps=self.max_steps,
|
||||
loss=state.log_history[-1].get("loss", None),
|
||||
eval_loss=state.log_history[-1].get("eval_loss", None),
|
||||
predict_loss=state.log_history[-1].get("predict_loss", None),
|
||||
reward=state.log_history[-1].get("reward", None),
|
||||
accuracy=state.log_history[-1].get("rewards/accuracies", None),
|
||||
learning_rate=state.log_history[-1].get("learning_rate", None),
|
||||
epoch=state.log_history[-1].get("epoch", None),
|
||||
loss=state.log_history[-1].get("loss"),
|
||||
eval_loss=state.log_history[-1].get("eval_loss"),
|
||||
predict_loss=state.log_history[-1].get("predict_loss"),
|
||||
reward=state.log_history[-1].get("reward"),
|
||||
accuracy=state.log_history[-1].get("rewards/accuracies"),
|
||||
lr=state.log_history[-1].get("learning_rate"),
|
||||
epoch=state.log_history[-1].get("epoch"),
|
||||
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
|
||||
elapsed_time=self.elapsed_time,
|
||||
remaining_time=self.remaining_time,
|
||||
@@ -304,16 +304,17 @@ class LogCallback(TrainerCallback):
|
||||
|
||||
if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
|
||||
vram_allocated, vram_reserved = get_peak_memory()
|
||||
logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)
|
||||
logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)
|
||||
logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
|
||||
logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)
|
||||
|
||||
logs = {k: v for k, v in logs.items() if v is not None}
|
||||
if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
|
||||
logger.info_rank0(
|
||||
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
|
||||
logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput", "N/A")
|
||||
)
|
||||
)
|
||||
if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")):
|
||||
log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
|
||||
for extra_key in ("reward", "accuracy", "throughput"):
|
||||
if logs.get(extra_key):
|
||||
log_str += f", '{extra_key}': {logs[extra_key]:.2f}"
|
||||
|
||||
logger.info_rank0("{" + log_str + "}")
|
||||
|
||||
if self.thread_pool is not None:
|
||||
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
||||
|
||||
Reference in New Issue
Block a user