Former-commit-id: b665e9e133bf2f6f10346c374eb0de8a96dd5c7e
This commit is contained in:
hiyouga 2023-10-20 23:28:52 +08:00
parent 1712a59280
commit 95697652f1
5 changed files with 44 additions and 48 deletions

View File

@ -59,7 +59,6 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 | | [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern | | [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B/14B | c_attn | chatml | | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B/14B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 | | [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - | | [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
@ -67,6 +66,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules. > **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
> >
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models. > For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models.
>
> Please refer to [template.py](src/llmtuner/extras/template.py) for a full list of models we supported.
## Supported Training Approaches ## Supported Training Approaches
@ -443,7 +444,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
This repository is licensed under the [Apache-2.0 License](LICENSE). This repository is licensed under the [Apache-2.0 License](LICENSE).
Please follow the model licenses to use the corresponding model weights: [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [Falcon](LICENSE) / [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/resolve/main/Baichuan%202%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [InternLM](https://github.com/InternLM/InternLM#open-source-license) / [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) Please follow the model licenses to use the corresponding model weights: [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [Falcon](LICENSE) / [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/resolve/main/Baichuan%202%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [InternLM](https://github.com/InternLM/InternLM#open-source-license) / [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE) / [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx)
## Citation ## Citation

View File

@ -59,7 +59,6 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 | | [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern | | [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B/14B | c_attn | chatml | | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B/14B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 | | [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - | | [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
@ -67,6 +66,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。 > **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。
> >
> 对于所有“基座”Base模型`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”Chat模型请务必使用**对应的模板**。 > 对于所有“基座”Base模型`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”Chat模型请务必使用**对应的模板**。
>
> 项目所支持模型的完整列表请参阅 [template.py](src/llmtuner/extras/template.py)。
## 训练方法 ## 训练方法
@ -442,7 +443,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
使用模型权重时,请遵循对应的模型协议:[LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [Falcon](LICENSE) / [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/resolve/main/Baichuan%202%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [InternLM](https://github.com/InternLM/InternLM#open-source-license) / [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) 使用模型权重时,请遵循对应的模型协议:[LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [Falcon](LICENSE) / [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/resolve/main/Baichuan%202%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [InternLM](https://github.com/InternLM/InternLM#open-source-license) / [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE) / [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx)
## 引用 ## 引用

View File

@ -84,10 +84,12 @@ def batch_inference(
prefix_char: str prefix_char: str
) -> List[str]: ) -> List[str]:
logits = chat_model.model(**batch_input).logits logits = chat_model.model(**batch_input).logits
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
nextword_logits = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
probs = torch.nn.functional.softmax( probs = torch.nn.functional.softmax(
torch.stack( torch.stack(
[ [
logits[:, -1, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]] nextword_logits[:, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]]
for choice in choices for choice in choices
], ],
dim=-1 dim=-1
@ -120,8 +122,8 @@ def evaluate(
checkpoint_dir=checkpoint_dir, checkpoint_dir=checkpoint_dir,
template=template template=template
)) ))
chat_model.tokenizer.padding_side = "left" # avoid overflow issue in batched inference for llama2
eval_template = eval_templates[lang] eval_template = eval_templates[lang]
assert chat_model.tokenizer.padding_side == "left", "only left-padded tensor can be accepted."
category_corrects: Dict[str, np.ndarray] = { category_corrects: Dict[str, np.ndarray] = {
subj: np.array([], dtype="bool") for subj in ["Average", "STEM", "Social Sciences", "Humanities", "Other"] subj: np.array([], dtype="bool") for subj in ["Average", "STEM", "Social Sciences", "Humanities", "Other"]

View File

@ -289,8 +289,8 @@ register_template(
r""" r"""
Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2 Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
https://huggingface.co/ziqingyang/chinese-alpaca-2-7b https://huggingface.co/ziqingyang/chinese-alpaca-2-13b
""" """
register_template( register_template(
name="llama2_zh", name="llama2_zh",
@ -307,7 +307,6 @@ register_template(
r""" r"""
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
https://github.com/ymcui/Chinese-LLaMA-Alpaca
""" """
register_template( register_template(
name="alpaca", name="alpaca",
@ -328,8 +327,8 @@ register_template(
r""" r"""
Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1 Supports: https://huggingface.co/lmsys/vicuna-7b-v1.5
https://huggingface.co/lmsys/vicuna-13b-delta-v1.1 https://huggingface.co/lmsys/vicuna-13b-v1.5
""" """
register_template( register_template(
name="vicuna", name="vicuna",
@ -365,44 +364,9 @@ register_template(
) )
r"""
Supports: https://github.com/CVI-SZU/Linly
"""
register_template(
name="linly",
prefix=[
"{{system}}"
],
prompt=[
"User: {{query}}\nBot: "
],
system="",
sep=[
"\n"
]
)
r"""
Supports: https://github.com/Neutralzz/BiLLa
"""
register_template(
name="billa",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\nAssistant: "
],
system="",
sep=[
"\n"
]
)
r""" r"""
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1 Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
https://huggingface.co/IDEA-CCNL/Ziya2-13B-Chat
""" """
register_template( register_template(
name="ziya", name="ziya",
@ -424,6 +388,8 @@ register_template(
r""" r"""
Supports: https://huggingface.co/BAAI/AquilaChat-7B Supports: https://huggingface.co/BAAI/AquilaChat-7B
https://huggingface.co/BAAI/AquilaChat2-7B
https://huggingface.co/BAAI/AquilaChat2-34B
""" """
register_template( register_template(
name="aquila", name="aquila",
@ -449,6 +415,7 @@ register_template(
r""" r"""
Supports: https://huggingface.co/internlm/internlm-chat-7b Supports: https://huggingface.co/internlm/internlm-chat-7b
https://huggingface.co/internlm/internlm-chat-20b
""" """
register_template( register_template(
name="intern", name="intern",
@ -542,6 +509,7 @@ register_template(
r""" r"""
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
https://huggingface.co/Qwen/Qwen-14B-Chat
""" """
register_template( register_template(
name="chatml", name="chatml",
@ -591,7 +559,29 @@ register_template(
r""" r"""
Supports: https://huggingface.co/xverse/XVERSE-13B-Chat Supports: https://huggingface.co/openchat/openchat_v3.2_super
"""
register_template(
name="openchat",
prefix=[
"{{system}}"
],
prompt=[
"GPT4 User: {{query}}",
{"token": "<|end_of_turn|>"},
"GPT4 Assistant: "
],
system="",
sep=[
{"token": "<|end_of_turn|>"}
],
efficient_eos=True
)
r"""
Supports: https://huggingface.co/xverse/XVERSE-7B-Chat
https://huggingface.co/xverse/XVERSE-13B-Chat
""" """
register_template( register_template(
name="xverse", name="xverse",

View File

@ -113,6 +113,8 @@ class DataArguments:
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f: with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
dataset_info = json.load(f) dataset_info = json.load(f)
except Exception: except Exception:
if self.dataset is not None:
raise ValueError("Cannot find dataset_info.json in `dataset_dir`.")
dataset_info = None dataset_info = None
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None] prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]