From 5ef2b8bddabc56bc74757792f8fb5f5ec50c691d Mon Sep 17 00:00:00 2001
From: BUAADreamer <1428195643@qq.com>
Date: Thu, 25 Apr 2024 22:40:53 +0800
Subject: [PATCH] modify some style

Former-commit-id: ece78a6d6af0673795824b2f95c266c042532eb3
---
 src/llmtuner/data/preprocess.py | 75 +++++++++++++++++----------------
 1 file changed, 38 insertions(+), 37 deletions(-)

diff --git a/src/llmtuner/data/preprocess.py b/src/llmtuner/data/preprocess.py
index 3487b761..59b49b9d 100644
--- a/src/llmtuner/data/preprocess.py
+++ b/src/llmtuner/data/preprocess.py
@@ -6,6 +6,7 @@ from ..extras.constants import IGNORE_INDEX
 from ..extras.logging import get_logger
 from .utils import Role
 
+
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments
     from transformers.tokenization_utils import AutoProcessor, PreTrainedTokenizer
@@ -17,7 +18,7 @@ logger = get_logger(__name__)
 
 
 def preprocess_pretrain_dataset(
-        examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
+    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
 ) -> Dict[str, List[List[int]]]:
     # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
     text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
@@ -34,7 +35,7 @@ def preprocess_pretrain_dataset(
         block_size = data_args.cutoff_len
         total_length = (total_length // block_size) * block_size
         result = {
-            k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
+            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
             for k, t in concatenated_examples.items()
         }
         if data_args.template == "gemma":
@@ -45,11 +46,11 @@ def preprocess_pretrain_dataset(
 
 
 def preprocess_supervised_dataset(
-        examples: Dict[str, List[Any]],
-        tokenizer: "PreTrainedTokenizer",
-        template: "Template",
-        data_args: "DataArguments",
-        processor: "AutoProcessor" = None,
+    examples: Dict[str, List[Any]],
+    tokenizer: "PreTrainedTokenizer",
+    template: "Template",
+    data_args: "DataArguments",
+    processor: "AutoProcessor" = None,
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
     # for multiturn examples, we only mask the prompt part in each prompt-response pair.
@@ -62,14 +63,14 @@ def preprocess_supervised_dataset(
         messages = examples["prompt"][i] + examples["response"][i]
         input_ids, labels = [], []
         for turn_idx, (source_ids, target_ids) in enumerate(
-                template.encode_multiturn(
-                    tokenizer,
-                    messages,
-                    examples["system"][i],
-                    examples["tools"][i],
-                    data_args.cutoff_len,
-                    data_args.reserved_label_len,
-                )
+            template.encode_multiturn(
+                tokenizer,
+                messages,
+                examples["system"][i],
+                examples["tools"][i],
+                data_args.cutoff_len,
+                data_args.reserved_label_len,
+            )
         ):
             if data_args.train_on_prompt:
                 source_mask = source_ids
@@ -95,10 +96,10 @@ def preprocess_supervised_dataset(
 
 
 def preprocess_packed_supervised_dataset(
-        examples: Dict[str, List[Any]],
-        tokenizer: "PreTrainedTokenizer",
-        template: "Template",
-        data_args: "DataArguments",
+    examples: Dict[str, List[Any]],
+    tokenizer: "PreTrainedTokenizer",
+    template: "Template",
+    data_args: "DataArguments",
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
@@ -110,7 +111,7 @@ def preprocess_packed_supervised_dataset(
 
         messages = examples["prompt"][i] + examples["response"][i]
         for source_ids, target_ids in template.encode_multiturn(
-                tokenizer, messages, examples["system"][i], examples["tools"][i]
+            tokenizer, messages, examples["system"][i], examples["tools"][i]
         ):
             if data_args.train_on_prompt:
                 source_mask = source_ids
@@ -132,19 +133,19 @@ def preprocess_packed_supervised_dataset(
     total_length = (total_length // block_size) * block_size
     # split by chunks of cutoff_len
     for i in range(0, total_length, block_size):
-        if not all(label == IGNORE_INDEX for label in labels[i: i + block_size]):
-            model_inputs["input_ids"].append(input_ids[i: i + block_size])
+        if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
+            model_inputs["input_ids"].append(input_ids[i : i + block_size])
             model_inputs["attention_mask"].append([1] * block_size)
-            model_inputs["labels"].append(labels[i: i + block_size])
+            model_inputs["labels"].append(labels[i : i + block_size])
 
     return model_inputs
 
 
 def preprocess_unsupervised_dataset(
-        examples: Dict[str, List[Any]],
-        tokenizer: "PreTrainedTokenizer",
-        template: "Template",
-        data_args: "DataArguments",
+    examples: Dict[str, List[Any]],
+    tokenizer: "PreTrainedTokenizer",
+    template: "Template",
+    data_args: "DataArguments",
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X` and labels with format `Y <eos>`
     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
@@ -178,10 +179,10 @@ def preprocess_unsupervised_dataset(
 
 
 def preprocess_pairwise_dataset(
-        examples: Dict[str, List[Any]],
-        tokenizer: "PreTrainedTokenizer",
-        template: "Template",
-        data_args: "DataArguments",
+    examples: Dict[str, List[Any]],
+    tokenizer: "PreTrainedTokenizer",
+    template: "Template",
+    data_args: "DataArguments",
 ) -> Dict[str, List[List[int]]]:
     # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
     model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
@@ -245,12 +246,12 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
 
 
 def get_preprocess_and_print_func(
-        tokenizer: "PreTrainedTokenizer",
-        template: "Template",
-        data_args: "DataArguments",
-        training_args: "Seq2SeqTrainingArguments",
-        stage: Literal["pt", "sft", "rm", "ppo"],
-        processor: Optional["AutoProcessor"] = None,
+    tokenizer: "PreTrainedTokenizer",
+    template: "Template",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    stage: Literal["pt", "sft", "rm", "ppo"],
+    processor: Optional["AutoProcessor"] = None,
 ) -> Tuple[Callable, Callable]:
     if stage == "pt":
         preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)