Fix PPO training and DPO evaluation

This commit is contained in:
hiyouga
2023-11-07 22:48:51 +08:00
parent 11c1e1e157
commit 01260d9754
5 changed files with 56 additions and 21 deletions

View File

@@ -1,5 +1,5 @@
from typing import Literal, Optional
from dataclasses import dataclass, field
from typing import Any, Dict, Literal, Optional
from dataclasses import asdict, dataclass, field
@dataclass
@@ -44,7 +44,7 @@ class ModelArguments:
)
checkpoint_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
metadata={"help": "Path to the directory(s) containing the model checkpoints as well as the configurations."}
)
flash_attn: Optional[bool] = field(
default=False,
@@ -83,3 +83,6 @@ class ModelArguments:
if self.quantization_bit is not None:
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
def to_dict(self) -> Dict[str, Any]:
    """Serialize these arguments to a plain dict.

    Uses ``dataclasses.asdict``, so every field of the dataclass is
    included (recursing into any nested dataclass values); field names
    become the dict keys.

    Returns:
        Dict[str, Any]: a field-name → field-value mapping of this instance.
    """
    return asdict(self)