mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-09-12 08:02:51 +08:00
Update data_args.py
Former-commit-id: cba673f491c5d97aba62aea03f310bd54fb3fe28
This commit is contained in:
parent
30a3c6e886
commit
788dc1c679
@ -31,12 +31,11 @@ class DataArguments:
|
||||
)
|
||||
dataset: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
|
||||
metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
|
||||
)
|
||||
eval_dataset: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of provided dataset(s) to use for eval during training. "
|
||||
"Use commas to separate multiple datasets."},
|
||||
metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
|
||||
)
|
||||
dataset_dir: str = field(
|
||||
default="data",
|
||||
@ -110,12 +109,33 @@ class DataArguments:
|
||||
default=None,
|
||||
metadata={"help": "Path to save or load the tokenized datasets."},
|
||||
)
|
||||
eval_tokenized_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to save or load the tokenized eval datasets."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
def split_arg(arg):
|
||||
if isinstance(arg, str):
|
||||
return [item.strip() for item in arg.split(",")]
|
||||
return arg
|
||||
|
||||
self.dataset = split_arg(self.dataset)
|
||||
self.eval_dataset = split_arg(self.eval_dataset)
|
||||
|
||||
if self.dataset is None and self.val_size > 1e-6:
|
||||
raise ValueError("Cannot specify `val_size` if `dataset` is None.")
|
||||
|
||||
if self.eval_dataset is not None and self.val_size > 1e-6:
|
||||
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
|
||||
|
||||
if self.interleave_probs is not None:
|
||||
if self.mix_strategy == "concat":
|
||||
raise ValueError("`interleave_probs` is only valid for interleaved mixing.")
|
||||
|
||||
self.interleave_probs = list(map(float, split_arg(self.interleave_probs)))
|
||||
if self.dataset is not None and len(self.dataset) != len(self.interleave_probs):
|
||||
raise ValueError("The length of dataset and interleave probs should be identical.")
|
||||
|
||||
if self.eval_dataset is not None and len(self.eval_dataset) != len(self.interleave_probs):
|
||||
raise ValueError("The length of eval dataset and interleave probs should be identical.")
|
||||
|
||||
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
|
||||
raise ValueError("Streaming mode should have an integer val size.")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user