mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 03:10:35 +08:00
@@ -1,25 +1,25 @@
|
||||
import tiktoken
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
|
||||
from itertools import chain
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.template import get_template_and_fix_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from llmtuner.hparams import DataArguments
|
||||
|
||||
|
||||
def preprocess_dataset(
|
||||
dataset: "Dataset",
|
||||
dataset: Union["Dataset", "IterableDataset"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"]
|
||||
) -> "Dataset":
|
||||
column_names = list(dataset.column_names)
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
|
||||
|
||||
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
|
||||
|
||||
Reference in New Issue
Block a user