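Refactor the effective-tokens-per-second metric: count tokens and report the metric only when `include_effective_tokens_per_second` is set, computing it via the `cal_effective_tokens` helper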

Former-commit-id: 40627c601e
This commit is contained in:
Ting
2024-11-19 20:33:18 +08:00
parent 32656bc50d
commit e27a0c3d53
4 changed files with 29 additions and 22 deletions

View File

@@ -17,10 +17,9 @@
from typing import TYPE_CHECKING, List, Optional
import torch.distributed as dist
from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import cal_effective_tokens
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
@@ -67,9 +66,10 @@ def run_dpo(
training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
effective_token_num = 0.0
for data in dataset_module["train_dataset"]:
effective_token_num += len(data["chosen_input_ids"])
effective_token_num += len(data["rejected_input_ids"])
if finetuning_args.include_effective_tokens_per_second:
for data in dataset_module["train_dataset"]:
effective_token_num += len(data["chosen_input_ids"])
effective_token_num += len(data["rejected_input_ids"])
# Initialize our Trainer
trainer = CustomDPOTrainer(
@@ -86,12 +86,10 @@ def run_dpo(
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
train_result.metrics["effective_tokens_per_sec"] = (
effective_token_num * train_result.metrics["epoch"] / train_result.metrics["train_runtime"]
)
if dist.is_initialized():
train_result.metrics["effective_tokens_per_sec"] = (
train_result.metrics["effective_tokens_per_sec"] / dist.get_world_size()
if finetuning_args.include_effective_tokens_per_second:
train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
)
trainer.save_model()