mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
Update loader.py
Former-commit-id: 3408af236f0b4ef64c3bfa791ef757828a74da7f
parent 268c0efd67
commit 4487a4a790
@@ -16,12 +16,13 @@ from .utils import checksum, merge_dataset
 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
-    from transformers import AutoProcessor, Seq2SeqTrainingArguments
+    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
     from transformers.tokenization_utils import PreTrainedTokenizer

     from ..hparams import DataArguments, ModelArguments
     from .parser import DatasetAttr


 logger = get_logger(__name__)

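The substance of this first hunk is the type-hint import: AutoProcessor is a factory class, while the objects it builds are ProcessorMixin subclasses, which makes ProcessorMixin the accurate name for annotations. A minimal sketch of the distinction (the checkpoint name is illustrative, not part of this commit):

# AutoProcessor is a factory, not a runtime type; what it returns is a
# ProcessorMixin subclass (e.g. LlavaProcessor), so ProcessorMixin is the
# class to name in type hints and isinstance checks.
from transformers import AutoProcessor, ProcessorMixin

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")  # hypothetical checkpoint
assert isinstance(processor, ProcessorMixin)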
@@ -114,12 +115,12 @@ def load_single_dataset(


 def get_dataset(
-    tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo"],
-    processor: Optional["AutoProcessor"] = None,
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"] = None,
 ) -> Union["Dataset", "IterableDataset"]:
     template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
     if data_args.train_on_prompt and template.efficient_eos:
@@ -149,7 +150,7 @@ def get_dataset(

     with training_args.main_process_first(desc="pre-process dataset"):
         preprocess_func, print_function = get_preprocess_and_print_func(
-            tokenizer, template, data_args, training_args, stage, processor
+            data_args, training_args, stage, template, tokenizer, processor
         )
         column_names = list(next(iter(dataset)).keys())
         kwargs = {}
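Both remaining hunks reorder positional parameters, so any caller that passed tokenizer as the first argument must be updated. A hedged sketch of a call site against the new signature; the hparams objects are assumed to be constructed elsewhere (e.g. by the package's argument parser), and the stage value is illustrative:

# Hypothetical call site for the reordered get_dataset signature:
# tokenizer now follows stage instead of leading the argument list.
dataset = get_dataset(
    model_args,              # ModelArguments, parsed elsewhere
    data_args,               # DataArguments, parsed elsewhere
    training_args,           # Seq2SeqTrainingArguments, parsed elsewhere
    stage="sft",
    tokenizer=tokenizer,     # moved after stage in this commit
    processor=processor,     # Optional["ProcessorMixin"], defaults to None
)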