From 9337568e230735390b8331d5a33b2b0f542d7ab0 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 20 Dec 2023 19:06:43 +0800 Subject: [PATCH] fix stop words Former-commit-id: dec360d5aee58c11804b8e45dd8d4c375086887f --- src/llmtuner/data/template.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py index 6d57698d..d4ec22d5 100644 --- a/src/llmtuner/data/template.py +++ b/src/llmtuner/data/template.py @@ -1,4 +1,5 @@ import tiktoken +from copy import deepcopy from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union @@ -223,19 +224,20 @@ def get_template_and_fix_tokenizer( template = templates.get(name, None) assert template is not None, "Template {} does not exist.".format(name) + stop_words = deepcopy(template.stop_words) if template.replace_eos: - if not template.stop_words: + if not stop_words: raise ValueError("Stop words are required to replace the EOS token.") - tokenizer.eos_token = template.stop_words.pop(0) + tokenizer.eos_token = stop_words.pop(0) logger.info("Replace eos token: {}".format(tokenizer.eos_token)) - if template.stop_words: + if stop_words: tokenizer.add_special_tokens( - dict(additional_special_tokens=template.stop_words), + dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False ) - logger.info("Add {} to stop words.".format(",".join(template.stop_words))) + logger.info("Add {} to stop words.".format(",".join(stop_words))) return template