mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-22 06:40:35 +08:00
@@ -31,7 +31,7 @@ class DataArguments:
|
||||
metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
||||
)
|
||||
dataset: Optional[str] = field(
|
||||
default="alpaca_en",
|
||||
default=None,
|
||||
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
|
||||
)
|
||||
dataset_dir: Optional[str] = field(
|
||||
@@ -46,13 +46,17 @@ class DataArguments:
|
||||
default=1024,
|
||||
metadata={"help": "The maximum length of the model inputs after tokenization."}
|
||||
)
|
||||
train_on_prompt: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to disable the mask on the prompt or not."}
|
||||
)
|
||||
streaming: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable streaming mode."}
|
||||
metadata={"help": "Enable dataset streaming."}
|
||||
)
|
||||
buffer_size: Optional[int] = field(
|
||||
default=16384,
|
||||
metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
|
||||
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
|
||||
)
|
||||
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
|
||||
default="concat",
|
||||
@@ -95,10 +99,20 @@ class DataArguments:
|
||||
metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
|
||||
raise ValueError("Streaming mode should have an integer val size.")
|
||||
|
||||
if self.streaming and self.max_samples is not None:
|
||||
raise ValueError("`max_samples` is incompatible with `streaming`.")
|
||||
|
||||
def init_for_training(self): # support mixing multiple datasets
|
||||
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
|
||||
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
||||
dataset_info = json.load(f)
|
||||
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
|
||||
try:
|
||||
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
||||
dataset_info = json.load(f)
|
||||
except Exception:
|
||||
dataset_info = None
|
||||
|
||||
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
|
||||
prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import torch
|
||||
from typing import Literal, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
@@ -19,6 +18,10 @@ class ModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
|
||||
)
|
||||
split_special_tokens: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
|
||||
)
|
||||
use_auth_token: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
|
||||
@@ -76,6 +79,9 @@ class ModelArguments:
|
||||
self.compute_dtype = None
|
||||
self.model_max_length = None
|
||||
|
||||
if self.split_special_tokens and self.use_fast_tokenizer:
|
||||
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
||||
|
||||
if self.checkpoint_dir is not None: # support merging multiple lora weights
|
||||
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user