diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py
index 4be79bad..0562d303 100644
--- a/src/llmtuner/dsets/preprocess.py
+++ b/src/llmtuner/dsets/preprocess.py
@@ -31,7 +31,11 @@ def preprocess_dataset(
 
     def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
         # build grouped texts with format `X1 X2 X3 ...` (without <eos>)
-        tokenized_examples = tokenizer(examples["prompt"], add_special_tokens=False)
+        if hasattr(tokenizer, "tokenizer"): # for tiktoken tokenizer (Qwen)
+            kwargs = dict(allowed_special="all")
+        else:
+            kwargs = dict(add_special_tokens=False)
+        tokenized_examples = tokenizer(examples["prompt"], **kwargs)
         concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
         total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
         block_size = data_args.max_source_length
diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py
index 0d114b43..6463c3d5 100644
--- a/src/llmtuner/extras/template.py
+++ b/src/llmtuner/extras/template.py
@@ -59,11 +59,26 @@ class Template:
         Aligns inputs to a special format.
         """
         prefix = [prefix] if prefix else self.prefix # use prefix if provided
-        prefix = prefix + self.sep if prefix else [] # add separator for non-empty prefix
         history = history if (history and self.use_history) else []
         history = history + [(query, resp)]
         return prefix, history
 
+    def _get_special_ids(
+        self,
+        tokenizer: "PreTrainedTokenizer"
+    ) -> Tuple[List[int], List[int]]:
+        if tokenizer.bos_token_id and getattr(tokenizer, "add_bos_token", False):
+            bos_ids = [tokenizer.bos_token_id]
+        else: # bos token is optional
+            bos_ids = []
+
+        if tokenizer.eos_token_id and getattr(tokenizer, "add_eos_token", False):
+            eos_ids = [tokenizer.eos_token_id]
+        else: # use the first stop word as the eos token
+            eos_ids = [tokenizer.convert_tokens_to_ids(self.stop_words[0])]
+
+        return bos_ids, eos_ids
+
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
@@ -73,20 +88,17 @@ class Template:
         r"""
         Encodes formatted inputs to pairs of token ids.
         """
-        if tokenizer.bos_token_id and getattr(tokenizer, "add_bos_token", False): # bos token is optional
-            bos_token_id = [tokenizer.bos_token_id]
-        else:
-            bos_token_id = []
-        eos_token_id = [tokenizer.eos_token_id] # eos token is required
+        bos_ids, eos_ids = self._get_special_ids(tokenizer)
+        sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
         encoded_pairs = []
         for turn_idx, (query, resp) in enumerate(history):
             if turn_idx == 0:
-                prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix)
+                prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix) + eos_ids + sep_ids
             else:
-                prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
+                prefix_ids = sep_ids
             query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
             resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
-            encoded_pairs.append((bos_token_id + prefix_ids + query_ids, resp_ids + eos_token_id))
+            encoded_pairs.append((bos_ids + prefix_ids + query_ids, resp_ids + eos_ids))
         return encoded_pairs
 
     def _convert_inputs_to_ids(
         self,
@@ -127,22 +139,15 @@ class Llama2Template(Template):
         r"""
         Encodes formatted inputs to pairs of token ids.
""" - if tokenizer.bos_token and getattr(tokenizer, "add_bos_token", False): # bos token is optional - bos_token_id = [tokenizer.bos_token_id] - else: - bos_token_id = [] - eos_token_id = [tokenizer.eos_token_id] # eos token is required + bos_ids, eos_ids = self._get_special_ids(tokenizer) encoded_pairs = [] assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single str." for turn_idx, (query, resp) in enumerate(history): - if turn_idx == 0: - prefix_ids = [] + if turn_idx == 0: # llama2 template has not sep_ids query = prefix[0] + query - else: - prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query) resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) - encoded_pairs.append((bos_token_id + prefix_ids + query_ids, resp_ids + eos_token_id)) + encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids)) return encoded_pairs @@ -226,11 +231,10 @@ register_template( "If you don't know the answer to a question, please don't share false information.\n<>\n\n" ], prompt=[ + {"token": ""}, "[INST] {{query}} [/INST] " ], - sep=[ - {"token": ""} - ], + sep=[], stop_words=[], use_history=True ) @@ -382,7 +386,6 @@ register_template( ":" ], sep=[ - {"token": ""}, "\n" ], stop_words=[ @@ -427,7 +430,6 @@ register_template( {"token": "<|assistant|>"} ], sep=[ - {"token": "<|end|>"}, "\n" ], stop_words=[ @@ -455,7 +457,6 @@ register_template( "assistant\n" ], sep=[ - {"token": "<|im_end|>"}, "\n" ], stop_words=[ diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py index ec19e962..690a6c80 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/tuner/core/loader.py @@ -68,9 +68,7 @@ def load_model_and_tokenizer( padding_side=model_args.padding_side, **config_kwargs ) - if tokenizer.eos_token_id is None: # fix qwen tokenizer - tokenizer.eos_token = "<|endoftext|>" - if tokenizer.pad_token_id is None: # add pad token + if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None: # add pad token tokenizer.pad_token = tokenizer.eos_token if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":