Former-commit-id: ec51986cf70b0bdd79b8141e45916670fb97a08e
This commit is contained in:
enji.zhou
2024-05-17 13:09:17 +08:00
parent 92b3697e2c
commit 66b5634ebf
12 changed files with 452 additions and 8 deletions

View File

@@ -14,7 +14,7 @@ from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
from .kto import run_kto
if TYPE_CHECKING:
from transformers import TrainerCallback
@@ -39,6 +39,8 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "orpo":
run_orpo(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "kto":
run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
else:
raise ValueError("Unknown task.")