Files
LLaMA-Factory/src/train_bash.py
2023-07-15 16:54:28 +08:00

24 lines
705 B
Python

from llmtuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
def main():
model_args, data_args, training_args, finetuning_args, general_args = get_train_args()
if general_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args)
elif general_args.stage == "sft":
run_sft(model_args, data_args, training_args, finetuning_args)
elif general_args.stage == "rm":
run_rm(model_args, data_args, training_args, finetuning_args)
elif general_args.stage == "ppo":
run_ppo(model_args, data_args, training_args, finetuning_args)
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()