from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List


if TYPE_CHECKING:
    from transformers.tokenization_utils import PreTrainedTokenizer

    from ...hparams import DataArguments


def preprocess_pretrain_dataset(
    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
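    """Tokenize the batched `prompt` column for pretraining.

    With `data_args.packing` disabled, each text is tokenized and truncated
    independently; with packing enabled, all texts are concatenated and
    re-split into blocks of `cutoff_len` tokens.
    """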
    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
    eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
    text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
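    # note: the llama3 template hard-codes `<|end_of_text|>` as the terminator,
    # presumably because that tokenizer's eos_token is reserved for chat turns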

    if not data_args.packing:
        # no packing: tokenize each text on its own and truncate to cutoff_len
        if data_args.template == "gemma":
            # gemma expects every sequence to start with a BOS token
            text_examples = [tokenizer.bos_token + example for example in text_examples]

        result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
    else:
        # packing: tokenize everything, concatenate the token ids, then
        # re-split them into blocks of exactly cutoff_len tokens
        tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
        total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
        block_size = data_args.cutoff_len
        # drop the tail that does not fill a whole block
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
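        # e.g. 10 concatenated tokens with cutoff_len == 4 yield two blocks of
        # 4 tokens; the trailing 2 tokens are discarded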
        if data_args.template == "gemma":
            # gemma requires BOS at the start of every sequence, so overwrite
            # the first token of each packed block with it
            for i in range(len(result["input_ids"])):
                result["input_ids"][i][0] = tokenizer.bos_token_id

    return result
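

# A minimal usage sketch (hypothetical wiring; the real call site lives in
# LLaMA-Factory's data pipeline, which also removes the raw columns and caches
# the tokenized result):
#
#   from functools import partial
#
#   preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
#   dataset = dataset.map(preprocess_func, batched=True, remove_columns=dataset.column_names)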