fix bug in PPO training

Former-commit-id: 2e99f0e53ce6de0acbcab85dd50aef874e8c6336
hiyouga
2023-11-16 02:32:54 +08:00
parent 77b1ed4deb
commit 71fe9ccdd4
3 changed files with 7 additions and 4 deletions

@@ -95,9 +95,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if not dataset_attr.ranking:
    raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
if finetuning_args.stage == "ppo" and model_args.reward_model is None:
    raise ValueError("Reward model is necessary for PPO training.")
if finetuning_args.stage == "ppo" and model_args.shift_attn:
    raise ValueError("PPO training is incompatible with S^2-Attn.")