diff --git a/README.md b/README.md index bc7b3c4b..a0df0e2c 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,6 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846 | [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 | | [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern | | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B/14B | c_attn | chatml | -| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse | | [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 | | [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - | @@ -67,6 +66,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846 > **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules. > > For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models. +> +> Please refer to [template.py](src/llmtuner/extras/template.py) for a full list of the models we support. ## Supported Training Approaches @@ -443,7 +444,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ This repository is licensed under the [Apache-2.0 License](LICENSE). 
-Please follow the model licenses to use the corresponding model weights: [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [Falcon](LICENSE) / [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/resolve/main/Baichuan%202%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [InternLM](https://github.com/InternLM/InternLM#open-source-license) / [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) +Please follow the model licenses to use the corresponding model weights: [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [Falcon](LICENSE) / [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/resolve/main/Baichuan%202%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [InternLM](https://github.com/InternLM/InternLM#open-source-license) / [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE) / [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) ## Citation diff --git a/README_zh.md b/README_zh.md index c0722d5c..7caf80c6 100644 --- a/README_zh.md +++ 
b/README_zh.md @@ -59,7 +59,6 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846 | [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 | | [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern | | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B/14B | c_attn | chatml | -| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse | | [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 | | [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - | @@ -67,6 +66,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846 > **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。 > > 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Chat)模型请务必使用**对应的模板**。 +> +> 项目所支持模型的完整列表请参阅 [template.py](src/llmtuner/extras/template.py)。 ## 训练方法 @@ -442,7 +443,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。 -使用模型权重时,请遵循对应的模型协议:[LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [Falcon](LICENSE) / [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/resolve/main/Baichuan%202%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [InternLM](https://github.com/InternLM/InternLM#open-source-license) / [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) 
+使用模型权重时,请遵循对应的模型协议:[LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [Falcon](LICENSE) / [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/resolve/main/Baichuan%202%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [InternLM](https://github.com/InternLM/InternLM#open-source-license) / [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE) / [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) ## 引用 diff --git a/src/evaluate.py b/src/evaluate.py index 89f170be..a198ed09 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -84,10 +84,12 @@ def batch_inference( prefix_char: str ) -> List[str]: logits = chat_model.model(**batch_input).logits + lengths = torch.sum(batch_input["attention_mask"], dim=-1) + nextword_logits = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0) probs = torch.nn.functional.softmax( torch.stack( [ - logits[:, -1, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]] + nextword_logits[:, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]] for choice in choices ], dim=-1 @@ -120,8 +122,8 @@ def evaluate( checkpoint_dir=checkpoint_dir, template=template )) + chat_model.tokenizer.padding_side = "right" # right-pad so lengths[i] - 1 indexes the last real token; also avoids overflow issue in batched inference for llama2 eval_template = eval_templates[lang] - assert chat_model.tokenizer.padding_side == "left", "only left-padded tensor can be accepted." 
category_corrects: Dict[str, np.ndarray] = { subj: np.array([], dtype="bool") for subj in ["Average", "STEM", "Social Sciences", "Humanities", "Other"] diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index ae486c69..9e88768e 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -289,8 +289,8 @@ register_template( r""" -Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2 - https://huggingface.co/ziqingyang/chinese-alpaca-2-7b +Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b + https://huggingface.co/ziqingyang/chinese-alpaca-2-13b """ register_template( name="llama2_zh", @@ -307,7 +307,6 @@ register_template( r""" Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff - https://github.com/ymcui/Chinese-LLaMA-Alpaca """ register_template( name="alpaca", @@ -328,8 +327,8 @@ register_template( r""" -Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1 - https://huggingface.co/lmsys/vicuna-13b-delta-v1.1 +Supports: https://huggingface.co/lmsys/vicuna-7b-v1.5 + https://huggingface.co/lmsys/vicuna-13b-v1.5 """ register_template( name="vicuna", @@ -365,44 +364,9 @@ register_template( ) -r""" -Supports: https://github.com/CVI-SZU/Linly -""" -register_template( - name="linly", - prefix=[ - "{{system}}" - ], - prompt=[ - "User: {{query}}\nBot: " - ], - system="", - sep=[ - "\n" - ] -) - - -r""" -Supports: https://github.com/Neutralzz/BiLLa -""" -register_template( - name="billa", - prefix=[ - "{{system}}" - ], - prompt=[ - "Human: {{query}}\nAssistant: " - ], - system="", - sep=[ - "\n" - ] -) - - r""" Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1 + https://huggingface.co/IDEA-CCNL/Ziya2-13B-Chat """ register_template( name="ziya", @@ -424,6 +388,8 @@ register_template( r""" Supports: https://huggingface.co/BAAI/AquilaChat-7B + https://huggingface.co/BAAI/AquilaChat2-7B + https://huggingface.co/BAAI/AquilaChat2-34B """ register_template( name="aquila", @@ -449,6 
+415,7 @@ register_template( r""" Supports: https://huggingface.co/internlm/internlm-chat-7b + https://huggingface.co/internlm/internlm-chat-20b """ register_template( name="intern", @@ -542,6 +509,7 @@ register_template( r""" Supports: https://huggingface.co/Qwen/Qwen-7B-Chat + https://huggingface.co/Qwen/Qwen-14B-Chat """ register_template( name="chatml", @@ -591,7 +559,29 @@ register_template( r""" -Supports: https://huggingface.co/xverse/XVERSE-13B-Chat +Supports: https://huggingface.co/openchat/openchat_v3.2_super +""" +register_template( + name="openchat", + prefix=[ + "{{system}}" + ], + prompt=[ + "GPT4 User: {{query}}", + {"token": "<|end_of_turn|>"}, + "GPT4 Assistant: " + ], + system="", + sep=[ + {"token": "<|end_of_turn|>"} + ], + efficient_eos=True +) + + +r""" +Supports: https://huggingface.co/xverse/XVERSE-7B-Chat + https://huggingface.co/xverse/XVERSE-13B-Chat """ register_template( name="xverse", diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index 839dec8f..ff246c00 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -113,6 +113,8 @@ class DataArguments: with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f: dataset_info = json.load(f) except Exception: + if self.dataset is not None: + raise ValueError("Cannot find dataset_info.json in `dataset_dir`.") dataset_info = None prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]