fix ppo train and dpo eval

hiyouga
2023-11-07 22:48:51 +08:00
parent 11c1e1e157
commit 01260d9754
5 changed files with 56 additions and 21 deletions


@@ -75,6 +75,14 @@ class FinetuningArguments:
         default=0.1,
         metadata={"help": "The beta parameter for the DPO loss."}
     )
+    dpo_ref_model: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the reference model used for the DPO training."}
+    )
+    dpo_ref_model_checkpoint: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
+    )
     upcast_layernorm: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to upcast the layernorm weights in fp32."}
@@ -91,7 +99,7 @@ class FinetuningArguments:
         if isinstance(self.additional_target, str):
             self.additional_target = [target.strip() for target in self.additional_target.split(",")]
-        assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
+        assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
 
     def save_to_json(self, json_path: str):
         r"""Saves the content of this instance in JSON format inside `json_path`."""