diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py
index 0f28ce1e..ad481ffd 100644
--- a/src/llmtuner/dsets/preprocess.py
+++ b/src/llmtuner/dsets/preprocess.py
@@ -30,7 +30,7 @@ def preprocess_dataset(
             yield query, response, history, prefix

     def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
-        # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
+        # build grouped texts with format `X1 X2 X3 ...` (without <eos>)
         tokenized_examples = tokenizer(examples["prompt"], add_special_tokens=False)
         concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
         total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
@@ -55,17 +55,17 @@ def preprocess_dataset(
         for query, response, history, prefix in construct_example(examples):
             input_ids, labels = [], []

-            for source_ids, target_ids in template.get_dialog(tokenizer, query, response, history, prefix): # TODO: fix bos
+            for source_ids, target_ids in template.get_dialog(tokenizer, query, response, history, prefix):
                 if len(source_ids) > data_args.max_source_length:
                     source_ids = source_ids[:data_args.max_source_length]

-                if len(target_ids) > data_args.max_target_length - 1: # eos token
-                    target_ids = target_ids[:data_args.max_target_length - 1]
+                if len(target_ids) > data_args.max_target_length:
+                    target_ids = target_ids[:data_args.max_target_length]

-                if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
+                if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
                     break

-                input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
-                labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
+                input_ids += source_ids + target_ids
+                labels += [IGNORE_INDEX] * len(source_ids) + target_ids

             model_inputs["input_ids"].append(input_ids)
             model_inputs["attention_mask"].append([1] * len(input_ids))
diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py
index 4ffd569a..4f8e6301 100644
--- a/src/llmtuner/extras/template.py
+++ b/src/llmtuner/extras/template.py
@@ -29,7 +29,7 @@ class Template:
         encoded_pairs = self._encode(tokenizer=tokenizer, prefix=prefix, history=history)
         prompt_ids = []
         for query_ids, resp_ids in encoded_pairs[:-1]:
-            prompt_ids = prompt_ids + query_ids + resp_ids + [tokenizer.eos_token_id]
+            prompt_ids = prompt_ids + query_ids + resp_ids
         prompt_ids = prompt_ids + encoded_pairs[-1][0]
         return prompt_ids, encoded_pairs[-1][1]

@@ -73,6 +73,11 @@ class Template:
         r"""
         Encodes formatted inputs to pairs of token ids.
         """
+        if tokenizer.bos_token and getattr(tokenizer, "add_bos_token", False): # bos token is optional
+            bos_token_id = [tokenizer.bos_token_id]
+        else:
+            bos_token_id = []
+        eos_token_id = [tokenizer.eos_token_id] # eos token is required
         encoded_pairs = []
         for turn_idx, (query, resp) in enumerate(history):
             if turn_idx == 0:
@@ -81,7 +86,7 @@ class Template:
                 prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
             query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
             resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
-            encoded_pairs.append((prefix_ids + query_ids, resp_ids))
+            encoded_pairs.append((bos_token_id + prefix_ids + query_ids, resp_ids + eos_token_id))
         return encoded_pairs

     def _convert_inputs_to_ids(
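For context (not part of the patch): a minimal, self-contained sketch of the token layout this change produces. `Template._encode` now attaches the optional bos and the required eos inside each encoded (source, target) pair, so `preprocess_supervised_dataset` concatenates pairs as-is instead of appending `eos_token_id` by hand. The names and token ids below (`BOS`, `EOS`, `encode_pairs`, `build_supervised_example`) are hypothetical, made up for illustration; only the masking and length-check logic mirrors the patched code.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # label value ignored by the cross-entropy loss
BOS, EOS = 1, 2      # hypothetical special-token ids

def encode_pairs(
    turns: List[Tuple[List[int], List[int]]],
    add_bos: bool = True
) -> List[Tuple[List[int], List[int]]]:
    # Mirrors the patched Template._encode: bos is optional (not every
    # tokenizer defines one), eos always terminates each response.
    bos = [BOS] if add_bos else []
    return [(bos + source, target + [EOS]) for source, target in turns]

def build_supervised_example(
    pairs: List[Tuple[List[int], List[int]]],
    max_length: int = 32
) -> Tuple[List[int], List[int]]:
    # Mirrors the patched preprocess_supervised_dataset loop: no manual
    # eos_token_id is appended, and the length check no longer needs "+ 1".
    input_ids, labels = [], []
    for source_ids, target_ids in pairs:
        if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
            break
        input_ids += source_ids + target_ids
        labels += [IGNORE_INDEX] * len(source_ids) + target_ids
    return input_ids, labels

turns = [([11, 12], [21]), ([13], [22, 23])]  # two toy (query, response) turns
input_ids, labels = build_supervised_example(encode_pairs(turns))
print(input_ids)  # [1, 11, 12, 21, 2, 1, 13, 22, 23, 2]
print(labels)     # [-100, -100, -100, 21, 2, -100, -100, 22, 23, 2]
```

Note that every turn, not just the final one, is wrapped as `<bos> ... <eos>`, which is why `get_prompt` in template.py can drop its own `[tokenizer.eos_token_id]` append for the history turns.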