diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py
index b331de35..9e2b6a20 100644
--- a/src/llmtuner/dsets/preprocess.py
+++ b/src/llmtuner/dsets/preprocess.py
@@ -47,9 +47,13 @@ def preprocess_dataset(
         kwargs = dict(add_special_tokens=True)
 
         if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer
+            add_eos_token_flag = getattr(tokenizer, "add_eos_token")
             setattr(tokenizer, "add_eos_token", True)
 
         tokenized_examples = tokenizer(examples["prompt"], **kwargs)
+        # Make sure the saved tokenizer is the same as the original
+        if hasattr(tokenizer, "add_eos_token"): # for Baichuan2 tokenizer
+            setattr(tokenizer, "add_eos_token", add_eos_token_flag)
         concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
         total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
         block_size = data_args.cutoff_len
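
Note: the point of the change is that `setattr(tokenizer, "add_eos_token", True)` mutates the shared tokenizer object in place, so without restoring the saved flag, a tokenizer serialized later (e.g. alongside a checkpoint) would keep `add_eos_token=True` even if the original config had it disabled. A minimal, self-contained sketch of this save/restore pattern is below; `DummyTokenizer` is a hypothetical stand-in for a tokenizer (such as LLaMA's or Baichuan2's) that exposes an `add_eos_token` attribute, not code from this repository.

class DummyTokenizer:
    """Hypothetical tokenizer exposing an `add_eos_token` flag."""

    def __init__(self):
        self.add_eos_token = False  # value from the original tokenizer config

    def __call__(self, texts):
        # Appends an EOS marker only when the flag is set,
        # mimicking how a LLaMA-style tokenizer behaves.
        return [t + (" </s>" if self.add_eos_token else "") for t in texts]

tokenizer = DummyTokenizer()

# Save the original flag, then force EOS insertion for pre-training packing.
saved_flag = getattr(tokenizer, "add_eos_token")
setattr(tokenizer, "add_eos_token", True)
tokenized = tokenizer(["hello world"])

# Restore the flag so a tokenizer saved afterwards matches the original.
setattr(tokenizer, "add_eos_token", saved_flag)

assert tokenized == ["hello world </s>"]
assert tokenizer.add_eos_token is False  # original config is intact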