diff --git a/src/llmtuner/data/preprocess.py b/src/llmtuner/data/preprocess.py
index 6108b245..3487b761 100644
--- a/src/llmtuner/data/preprocess.py
+++ b/src/llmtuner/data/preprocess.py
@@ -6,7 +6,6 @@ from ..extras.constants import IGNORE_INDEX
 from ..extras.logging import get_logger
 from .utils import Role
 
-
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments
     from transformers.tokenization_utils import AutoProcessor, PreTrainedTokenizer
@@ -18,7 +17,7 @@ logger = get_logger(__name__)
 
 
 def preprocess_pretrain_dataset(
-    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
+        examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
 ) -> Dict[str, List[List[int]]]:
     # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
     text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
@@ -35,7 +34,7 @@ def preprocess_pretrain_dataset(
         block_size = data_args.cutoff_len
         total_length = (total_length // block_size) * block_size
         result = {
-            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+            k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
             for k, t in concatenated_examples.items()
         }
         if data_args.template == "gemma":
@@ -46,11 +45,11 @@ def preprocess_pretrain_dataset(
 
 
 def preprocess_supervised_dataset(
-    examples: Dict[str, List[Any]],
-    tokenizer: "PreTrainedTokenizer",
-    template: "Template",
-    data_args: "DataArguments",
-    processor: "AutoProcessor" = None,
+        examples: Dict[str, List[Any]],
+        tokenizer: "PreTrainedTokenizer",
+        template: "Template",
+        data_args: "DataArguments",
+        processor: "AutoProcessor" = None,
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... Y <eos>`
     # for multiturn examples, we only mask the prompt part in each prompt-response pair.
@@ -63,14 +62,14 @@ def preprocess_supervised_dataset(
         messages = examples["prompt"][i] + examples["response"][i]
         input_ids, labels = [], []
         for turn_idx, (source_ids, target_ids) in enumerate(
-            template.encode_multiturn(
-                tokenizer,
-                messages,
-                examples["system"][i],
-                examples["tools"][i],
-                data_args.cutoff_len,
-                data_args.reserved_label_len,
-            )
+                template.encode_multiturn(
+                    tokenizer,
+                    messages,
+                    examples["system"][i],
+                    examples["tools"][i],
+                    data_args.cutoff_len,
+                    data_args.reserved_label_len,
+                )
         ):
             if data_args.train_on_prompt:
                 source_mask = source_ids
@@ -96,10 +95,10 @@ def preprocess_supervised_dataset(
 
 
 def preprocess_packed_supervised_dataset(
-    examples: Dict[str, List[Any]],
-    tokenizer: "PreTrainedTokenizer",
-    template: "Template",
-    data_args: "DataArguments",
+        examples: Dict[str, List[Any]],
+        tokenizer: "PreTrainedTokenizer",
+        template: "Template",
+        data_args: "DataArguments",
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... Y1 <eos> <ignore> ... Y2 <eos>`
@@ -111,7 +110,7 @@ def preprocess_packed_supervised_dataset(
 
         messages = examples["prompt"][i] + examples["response"][i]
         for source_ids, target_ids in template.encode_multiturn(
-            tokenizer, messages, examples["system"][i], examples["tools"][i]
+                tokenizer, messages, examples["system"][i], examples["tools"][i]
         ):
             if data_args.train_on_prompt:
                 source_mask = source_ids
@@ -133,19 +132,19 @@ def preprocess_packed_supervised_dataset(
     total_length = (total_length // block_size) * block_size
     # split by chunks of cutoff_len
     for i in range(0, total_length, block_size):
-        if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
-            model_inputs["input_ids"].append(input_ids[i : i + block_size])
+        if not all(label == IGNORE_INDEX for label in labels[i: i + block_size]):
+            model_inputs["input_ids"].append(input_ids[i: i + block_size])
             model_inputs["attention_mask"].append([1] * block_size)
-            model_inputs["labels"].append(labels[i : i + block_size])
+            model_inputs["labels"].append(labels[i: i + block_size])
 
     return model_inputs
 
 
 def preprocess_unsupervised_dataset(
-    examples: Dict[str, List[Any]],
-    tokenizer: "PreTrainedTokenizer",
-    template: "Template",
-    data_args: "DataArguments",
+        examples: Dict[str, List[Any]],
+        tokenizer: "PreTrainedTokenizer",
+        template: "Template",
+        data_args: "DataArguments",
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X` and labels with format `Y <eos>`
     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
@@ -179,10 +178,10 @@ def preprocess_unsupervised_dataset(
 
 
 def preprocess_pairwise_dataset(
-    examples: Dict[str, List[Any]],
-    tokenizer: "PreTrainedTokenizer",
-    template: "Template",
-    data_args: "DataArguments",
+        examples: Dict[str, List[Any]],
+        tokenizer: "PreTrainedTokenizer",
+        template: "Template",
+        data_args: "DataArguments",
 ) -> Dict[str, List[List[int]]]:
     # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
     model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
@@ -246,12 +245,12 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
 
 
 def get_preprocess_and_print_func(
-    tokenizer: "PreTrainedTokenizer",
-    template: "Template",
-    data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments",
-    stage: Literal["pt", "sft", "rm", "ppo"],
-    processor: Optional["AutoProcessor"] = None,
+        tokenizer: "PreTrainedTokenizer",
+        template: "Template",
+        data_args: "DataArguments",
+        training_args: "Seq2SeqTrainingArguments",
+        stage: Literal["pt", "sft", "rm", "ppo"],
+        processor: Optional["AutoProcessor"] = None,
 ) -> Tuple[Callable, Callable]:
     if stage == "pt":
         preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
@@ -280,5 +279,4 @@ def get_preprocess_and_print_func(
             preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
         )
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
-
     return preprocess_func, print_function
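
For context on the slicing hunks above, here is a minimal standalone sketch of the cutoff_len packing that preprocess_packed_supervised_dataset performs: concatenated token ids are split into fixed-size blocks, the tail remainder is dropped, and blocks whose labels are entirely masked are skipped. The helper name `chunk_into_blocks` and the toy values are illustrative only and are not part of the patch.

from typing import Dict, List

IGNORE_INDEX = -100  # same sentinel the module imports from ..extras.constants


def chunk_into_blocks(
    input_ids: List[int], labels: List[int], block_size: int
) -> Dict[str, List[List[int]]]:
    # illustrative helper: split concatenated ids/labels into blocks of block_size,
    # dropping the remainder and skipping blocks with fully-masked labels
    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
    total_length = (len(input_ids) // block_size) * block_size  # drop the tail
    for i in range(0, total_length, block_size):
        block_labels = labels[i : i + block_size]
        if not all(label == IGNORE_INDEX for label in block_labels):
            model_inputs["input_ids"].append(input_ids[i : i + block_size])
            model_inputs["attention_mask"].append([1] * block_size)
            model_inputs["labels"].append(block_labels)
    return model_inputs


if __name__ == "__main__":
    # toy example: 10 tokens packed into blocks of 4 -> two blocks kept, 2 tokens dropped
    ids = list(range(10))
    labs = [IGNORE_INDEX] * 3 + list(range(3, 10))
    print(chunk_into_blocks(ids, labs, block_size=4))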