From 8b4ef062b7cfb4fb4a765b73f55f737b0faa1dd7 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Fri, 7 Jul 2023 11:02:28 +0800 Subject: [PATCH] support InternLM Former-commit-id: a2f507c56238d7fb2670edbab52d6b275f245e27 --- README.md | 2 ++ src/utils/seq2seq.py | 4 ++-- src/utils/template.py | 11 +++++++++++ 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 452f2843..1a5ac9ce 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,8 @@ ## Changelog +[23/07/07] Now we support training the InternLM-7B model in this repo. Try the `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use the `--prompt_template intern` argument when using the InternLM-chat model. + [23/07/05] Now we support training the Falcon-7B/40B models in this repo. Try `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments to use the Falcon model. [23/06/29] We provide a reproducible example of training a chat model using instruction-following datasets, see this [HuggingFace Repo](https://huggingface.co/hiyouga/baichuan-7b-sft) for details. 
diff --git a/src/utils/seq2seq.py b/src/utils/seq2seq.py index cab7a919..f9a9dc6b 100644 --- a/src/utils/seq2seq.py +++ b/src/utils/seq2seq.py @@ -104,8 +104,8 @@ class Seq2SeqPeftTrainer(PeftTrainer): preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id) labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id) - decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) - decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True) + decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True) with open(output_prediction_file, "w", encoding="utf-8") as writer: res: List[str] = [] diff --git a/src/utils/template.py b/src/utils/template.py index b4fadddb..87d9a31b 100644 --- a/src/utils/template.py +++ b/src/utils/template.py @@ -114,6 +114,17 @@ class Template: use_history=True ) + elif self.name == "intern": + r""" + Supports: https://huggingface.co/internlm/internlm-chat-7b + """ + self._register_template( + prefix="", + prompt="<|User|>:{query}\n<|Bot|>:", + sep="\n", + use_history=True + ) + else: raise ValueError("Template {} does not exist.".format(self.name))