mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 11:20:35 +08:00
fix ppo train and dpo eval
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from typing import Literal, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
from dataclasses import asdict, dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -44,7 +44,7 @@ class ModelArguments:
|
||||
)
|
||||
checkpoint_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
|
||||
metadata={"help": "Path to the directory(s) containing the model checkpoints as well as the configurations."}
|
||||
)
|
||||
flash_attn: Optional[bool] = field(
|
||||
default=False,
|
||||
@@ -83,3 +83,6 @@ class ModelArguments:
|
||||
|
||||
if self.quantization_bit is not None:
|
||||
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
Reference in New Issue
Block a user