Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-16 20:00:36 +08:00)
support ORPO
This commit is contained in:
@@ -2,13 +2,12 @@
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
-from ...data import get_dataset, split_dataset
+from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...hparams import ModelArguments
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..utils import create_modelcard_and_push, create_ref_model
|
||||
-from .collator import DPODataCollatorWithPadding
|
||||
from .trainer import CustomDPOTrainer
|
||||
|
||||
|
||||
@@ -29,7 +28,7 @@ def run_dpo(
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
|
||||
-data_collator = DPODataCollatorWithPadding(
+data_collator = PairwiseDataCollatorWithPadding(
|
||||
tokenizer=tokenizer,
|
||||
pad_to_multiple_of=8,
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
|
||||
@@ -64,7 +63,7 @@ def run_dpo(
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
-plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "accuracy"])
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
|
||||
Reference in New Issue
Block a user