From ce9ffca0d98eb3d8050c22473f9e32dbc9c17c4d Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 9 Aug 2023 23:10:20 +0800 Subject: [PATCH] fix template Former-commit-id: ac29f4d5f0d9d514aec9224fd751b9eb49430e7e --- src/llmtuner/extras/template.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 413333ac..dd88782d 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -97,10 +97,10 @@ class Template: sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) encoded_pairs = [] for turn_idx, (query, resp) in enumerate(history): - if turn_idx == 0 and prefix: - prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix) + eos_ids + sep_ids - else: + if turn_idx != 0: prefix_ids = sep_ids + elif prefix: + prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix) + eos_ids + sep_ids query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query) resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) encoded_pairs.append((bos_ids + prefix_ids + query_ids, resp_ids + eos_ids))