This commit is contained in:
hiyouga
2023-11-14 18:07:20 +08:00
parent 3743b7420b
commit d125ef5535
2 changed files with 32 additions and 22 deletions

View File

@@ -52,6 +52,10 @@ class DataArguments:
default=1024,
metadata={"help": "The maximum length of the model inputs after tokenization."}
)
# Token budget set aside for the label portion of a tokenized example.
# Defaults to 1; __post_init__ rejects any value >= cutoff_len.
# NOTE(review): the help text says "maximum length", while the name and the
# validation suggest a reserved budget for the label — confirm the wording.
reserved_label_len: Optional[int] = field(
default=1,
metadata={"help": "The maximum length reserved for label after tokenization."}
)
train_on_prompt: Optional[bool] = field(
default=False,
metadata={"help": "Whether to disable the mask on the prompt or not."}
@@ -110,6 +114,9 @@ class DataArguments:
)
def __post_init__(self):
# Sanity-check argument combinations; each failed check raises ValueError.
# Presumably invoked by the dataclass-generated __init__ — the class
# decorator is outside this view, TODO confirm.
# NOTE(review): this view ends at a diff hunk boundary, so further checks
# may follow below.
#
# The label reservation must be strictly smaller than cutoff_len.
# NOTE(review): reserved_label_len is annotated Optional[int]; a None value
# would make this comparison raise TypeError — confirm None is never passed.
if self.reserved_label_len >= self.cutoff_len:
raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")
# With streaming enabled, a val_size strictly between 1e-6 and 1 is rejected;
# the error message asks for an integer val size instead.
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
raise ValueError("Streaming mode should have an integer val size.")