From 7a89fce4c7ee3d35a4dbc43f5a3029537d9c0961 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sat, 5 Aug 2023 00:07:54 +0800 Subject: [PATCH] fix llama2 template Former-commit-id: e4a15f863c879f28a716a90f7c928ac02f059b6e --- src/llmtuner/extras/template.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 4f8e6301..0d6eb57c 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -58,7 +58,7 @@ class Template: r""" Aligns inputs to a special format. """ - prefix = [prefix] if prefix is not None else self.prefix # use prefix if provided + prefix = [prefix] if prefix else self.prefix # use prefix if provided prefix = prefix + self.sep if prefix else [] # add separator for non-empty prefix history = history if (history and self.use_history) else [] history = history + [(query, resp)] @@ -124,6 +124,11 @@ class Llama2Template(Template): r""" Encodes formatted inputs to pairs of token ids. """ + if tokenizer.bos_token and getattr(tokenizer, "add_bos_token", False): # bos token is optional + bos_token_id = [tokenizer.bos_token_id] + else: + bos_token_id = [] + eos_token_id = [tokenizer.eos_token_id] # eos token is required encoded_pairs = [] assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single str." for turn_idx, (query, resp) in enumerate(history): @@ -134,7 +139,7 @@ class Llama2Template(Template): prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query) resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) - encoded_pairs.append((prefix_ids + query_ids, resp_ids)) + encoded_pairs.append((bos_token_id + prefix_ids + query_ids, resp_ids + eos_token_id)) return encoded_pairs @@ -154,8 +159,8 @@ def register_template( prefix=prefix, prompt=prompt, sep=sep, - use_history=use_history, - stop_words=stop_words + stop_words=stop_words, + use_history=use_history )