from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List


if TYPE_CHECKING:
    from transformers.tokenization_utils import PreTrainedTokenizer

    from ...hparams import DataArguments


def preprocess_pretrain_dataset(
    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
    r"""Tokenize a batch of raw pre-training texts.

    The text of each example is ``messages[0]["content"]`` of its ``"prompt"`` entry, followed by an
    end-of-text marker. The marker is ``"<|end_of_text|>"`` for the ``llama3`` template and
    ``tokenizer.eos_token`` otherwise.

    Without packing, every text is tokenized on its own and truncated to ``data_args.cutoff_len``
    tokens. With packing (``data_args.packing``), all texts are concatenated into a single stream
    ``X1 X2 X3 ...``. The stream is cut into consecutive blocks of exactly ``cutoff_len`` tokens. Any
    trailing tokens that do not fill a whole block are dropped, so a batch shorter than
    ``cutoff_len`` yields no blocks at all.

    For the ``gemma`` template, a BOS token is handled differently per mode. It is prepended to every
    text when packing is off. It overwrites the first token of every block when packing is on.

    Args:
        examples: Batched columns. Must contain a ``"prompt"`` column whose entries are lists of
            message dicts with a ``"content"`` key.
        tokenizer: Tokenizer, always called with ``add_special_tokens=False``.
        data_args: Supplies ``template``, ``packing`` and ``cutoff_len``.

    Returns:
        The tokenizer output fields (``input_ids`` and whatever else the tokenizer returns). Each
        field maps to a list of token-id lists.
    """
    # The llama3 template uses this literal separator instead of tokenizer.eos_token.
    eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
    text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
    add_bos = data_args.template == "gemma"

    if not data_args.packing:
        if add_bos:
            text_examples = [tokenizer.bos_token + text for text in text_examples]
        return tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)

    # Packing: flatten each field of the batch into one token stream. chain.from_iterable avoids
    # unpacking the whole batch into positional arguments, which chain(*...) would do.
    tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
    concatenated_examples = {
        k: list(chain.from_iterable(tokenized_examples[k])) for k in tokenized_examples.keys()
    }

    block_size = data_args.cutoff_len
    # Length is measured on the first field. The fields are expected to be token-aligned.
    total_length = len(next(iter(concatenated_examples.values())))
    # Round down to a whole number of blocks. The trailing partial block is discarded.
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }

    if add_bos:
        # NOTE(review): this overwrites, rather than prepends, the first token of every block.
        # One real token per block is lost. Kept as-is so every block stays exactly cutoff_len long.
        for block in result["input_ids"]:
            block[0] = tokenizer.bos_token_id

    return result