Merge pull request #2226 from hiyouga/dev

support function calling

Former-commit-id: 9986cc6dd1d10d8ed4dd7140b2a8295c9d7fd72b
This commit is contained in:
hoshi-hiyouga 2024-01-18 14:31:28 +08:00 committed by GitHub
commit 10d595b507
72 changed files with 1383 additions and 1109 deletions

View File

@ -55,14 +55,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog
[24/01/18] We supported **agent tuning** for most models, equipping the model with tool-using abilities by fine-tuning with `--dataset glaive_toolcall` (a usage sketch follows this changelog excerpt).
[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `--use_unsloth` argument to activate unsloth patch. It achieves 1.7x speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage.
<details><summary>Full Changelog</summary>
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage.
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`.
[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention.
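For readers following the new agent-tuning entry above, here is a minimal usage sketch via the package's exported `run_exp` helper; the model path, template and output directory are placeholders rather than values taken from this commit:

```python
from llmtuner import run_exp

# Hedged example: supervised fine-tuning on the new function-calling dataset.
# All paths and hyperparameters below are illustrative placeholders.
run_exp(dict(
    stage="sft",
    do_train=True,
    model_name_or_path="path/to/base/model",  # placeholder
    dataset="glaive_toolcall",
    template="default",                       # pick the template matching your model
    finetuning_type="lora",
    output_dir="path/to/output",              # placeholder
    per_device_train_batch_size=1,
    num_train_epochs=1.0,
))
```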
@ -95,14 +97,13 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| Model                                                | Model size                  | Default module  | Template  |
| ---------------------------------------------------- | --------------------------- | --------------- | --------- |
| [Baichuan](https://huggingface.co/baichuan-inc)       | 7B/13B                      | W_pack          | baichuan  |
| [Baichuan2](https://huggingface.co/baichuan-inc)      | 7B/13B                      | W_pack          | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience/bloom)      | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | -         |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz)    | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | -         |
| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b)  | 6B                          | query_key_value | chatglm3  |
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai)  | 7B/16B/67B                  | q_proj,v_proj   | deepseek  |
| [Falcon](https://huggingface.co/tiiuae)               | 7B/40B/180B                 | query_key_value | falcon    |
| [InternLM](https://huggingface.co/internlm)           | 7B/20B                      | q_proj,v_proj   | intern    |
| [InternLM2](https://huggingface.co/internlm)          | 7B/20B                      | wqkv            | intern2   |
| [LLaMA](https://github.com/facebookresearch/llama)    | 7B/13B/33B/65B              | q_proj,v_proj   | -         |
| [LLaMA-2](https://huggingface.co/meta-llama)          | 7B/13B/70B                  | q_proj,v_proj   | llama2    |
| [Mistral](https://huggingface.co/mistralai)           | 7B                          | q_proj,v_proj   | mistral   |
@ -183,6 +184,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
</details>

View File

@ -55,6 +55,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
## Changelog
[24/01/18] We implemented **agent tuning** for most models: specify `--dataset glaive_toolcall` during fine-tuning to equip the model with tool-calling abilities.
[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s LoRA training acceleration for the LLaMA, Mistral and Yi models. Use the `--use_unsloth` argument to enable the unsloth optimization. It provides a 1.7x training speedup; see [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
[23/12/12] We supported fine-tuning the latest mixture-of-experts model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**. See the hardware requirements [here](#硬件依赖).
@ -95,14 +97,13 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
| Model                                                | Model size                  | Default module  | Template  |
| ---------------------------------------------------- | --------------------------- | --------------- | --------- |
| [Baichuan](https://huggingface.co/baichuan-inc)       | 7B/13B                      | W_pack          | baichuan  |
| [Baichuan2](https://huggingface.co/baichuan-inc)      | 7B/13B                      | W_pack          | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience/bloom)      | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | -         |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz)    | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | -         |
| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b)  | 6B                          | query_key_value | chatglm3  |
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai)  | 7B/16B/67B                  | q_proj,v_proj   | deepseek  |
| [Falcon](https://huggingface.co/tiiuae)               | 7B/40B/180B                 | query_key_value | falcon    |
| [InternLM](https://huggingface.co/internlm)           | 7B/20B                      | q_proj,v_proj   | intern    |
| [InternLM2](https://huggingface.co/internlm)          | 7B/20B                      | wqkv            | intern2   |
| [LLaMA](https://github.com/facebookresearch/llama)    | 7B/13B/33B/65B              | q_proj,v_proj   | -         |
| [LLaMA-2](https://huggingface.co/meta-llama)          | 7B/13B/70B                  | q_proj,v_proj   | llama2    |
| [Mistral](https://huggingface.co/mistralai)           | 7B                          | q_proj,v_proj   | mistral   |
@ -183,6 +184,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
</details>

View File

@ -165,9 +165,13 @@
"hf_hub_url": "HuggingFaceH4/ultrachat_200k", "hf_hub_url": "HuggingFaceH4/ultrachat_200k",
"ms_hub_url": "AI-ModelScope/ultrachat_200k", "ms_hub_url": "AI-ModelScope/ultrachat_200k",
"columns": { "columns": {
"messages": "messages", "messages": "messages"
"role": "role", },
"content": "content" "tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "human",
"assistant_tag": "assistant"
},
"formatting": "sharegpt"
},
@ -180,9 +184,13 @@
"hf_hub_url": "lmsys/lmsys-chat-1m", "hf_hub_url": "lmsys/lmsys-chat-1m",
"ms_hub_url": "AI-ModelScope/lmsys-chat-1m", "ms_hub_url": "AI-ModelScope/lmsys-chat-1m",
"columns": { "columns": {
"messages": "conversation", "messages": "conversation"
"role": "role", },
"content": "content" "tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "human",
"assistant_tag": "assistant"
},
"formatting": "sharegpt"
},
@ -190,6 +198,14 @@
"hf_hub_url": "WizardLM/WizardLM_evol_instruct_V2_196k", "hf_hub_url": "WizardLM/WizardLM_evol_instruct_V2_196k",
"formatting": "sharegpt" "formatting": "sharegpt"
}, },
"glaive_toolcall": {
"file_name": "glaive_toolcall_10k.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"tools": "tools"
}
},
"hh_rlhf_en": { "hh_rlhf_en": {
"script_url": "hh_rlhf_en", "script_url": "hh_rlhf_en",
"columns": { "columns": {

View File

@ -0,0 +1 @@
4748dff00d1dc42768a5b6cc772143c313017812

View File

@ -9,7 +9,6 @@ scipy
einops
sentencepiece
protobuf
tiktoken
jieba
rouge-chinese
nltk

View File

@ -8,3 +8,12 @@ from llmtuner.webui import create_ui, create_web_demo
__version__ = "0.4.0" __version__ = "0.4.0"
__all__ = [
"create_app",
"ChatModel",
"Evaluator",
"export_model",
"run_exp",
"create_ui",
"create_web_demo"
]
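With these exports in place, downstream code can import the main entry points directly from the package root, e.g.:

```python
# Top-level imports now available (per __all__ above).
from llmtuner import ChatModel, Evaluator, create_app, create_ui, create_web_demo, export_model, run_exp
```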

View File

@ -1 +1,4 @@
from llmtuner.api.app import create_app from .app import create_app
__all__ = ["create_app"]

View File

@ -5,7 +5,7 @@ from typing import List, Tuple
from pydantic import BaseModel
from contextlib import asynccontextmanager
from llmtuner.api.protocol import ( from .protocol import (
Role,
Finish,
ModelCard,
@ -21,9 +21,9 @@ from llmtuner.api.protocol import (
ScoreEvaluationRequest,
ScoreEvaluationResponse
)
from llmtuner.chat import ChatModel from ..chat import ChatModel
from llmtuner.extras.misc import torch_gc from ..extras.misc import torch_gc
from llmtuner.extras.packages import ( from ..extras.packages import (
is_fastapi_availble, is_starlette_available, is_uvicorn_available
)

View File

@ -1,15 +1,17 @@
import time
from enum import Enum from enum import Enum, unique
from pydantic import BaseModel, Field
from typing import List, Optional
@unique
class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
@unique
class Finish(str, Enum):
    STOP = "stop"
    LENGTH = "length"

View File

@ -1 +1,4 @@
from llmtuner.chat.chat_model import ChatModel from .chat_model import ChatModel
__all__ = ["ChatModel"]

View File

@ -1,13 +1,13 @@
import torch import torch
import tiktoken
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
from threading import Thread from threading import Thread
from transformers import GenerationConfig, TextIteratorStreamer from transformers import GenerationConfig, TextIteratorStreamer
from llmtuner.data.template import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer, Role
from llmtuner.extras.misc import get_logits_processor from ..extras.misc import get_logits_processor
from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer from ..model import dispatch_model, load_model_and_tokenizer
from ..hparams import get_infer_args
@dataclass @dataclass
@ -36,10 +36,19 @@ class ChatModel:
query: str, query: str,
history: Optional[List[Tuple[str, str]]] = None, history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs **input_kwargs
) -> Tuple[Dict[str, Any], int]: ) -> Tuple[Dict[str, Any], int]:
messages = []
if history is not None:
for old_prompt, old_response in history:
messages.append({"role": Role.USER, "content": old_prompt})
messages.append({"role": Role.ASSISTANT, "content": old_response})
messages.append({"role": Role.USER, "content": query})
messages.append({"role": Role.ASSISTANT, "content": ""})
prompt, _ = self.template.encode_oneturn( prompt, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, query=query, resp="", history=history, system=system tokenizer=self.tokenizer, messages=messages, system=system, tools=tools
) )
prompt_length = len(prompt) prompt_length = len(prompt)
input_ids = torch.tensor([prompt], device=self.model.device) input_ids = torch.tensor([prompt], device=self.model.device)
@ -90,6 +99,7 @@ class ChatModel:
query: str, query: str,
history: Optional[List[Tuple[str, str]]] = None, history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs **input_kwargs
) -> List[Response]: ) -> List[Response]:
r""" r"""
@ -97,7 +107,7 @@ class ChatModel:
Returns: [(response_text, prompt_length, response_length)] * n (default n=1) Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
""" """
gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs) gen_kwargs, prompt_length = self._process_args(query, history, system, tools, **input_kwargs)
generate_output = self.model.generate(**gen_kwargs) generate_output = self.model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:] response_ids = generate_output[:, prompt_length:]
response = self.tokenizer.batch_decode( response = self.tokenizer.batch_decode(
@ -122,9 +132,10 @@ class ChatModel:
query: str, query: str,
history: Optional[List[Tuple[str, str]]] = None, history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs **input_kwargs
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs) gen_kwargs, _ = self._process_args(query, history, system, tools, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer gen_kwargs["streamer"] = streamer
@ -139,11 +150,6 @@ class ChatModel:
batch_input: List[str], batch_input: List[str],
**input_kwargs **input_kwargs
) -> List[float]: ) -> List[float]:
if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=True)
max_length = input_kwargs.pop("max_length", None) max_length = input_kwargs.pop("max_length", None)
device = getattr(self.model.pretrained_model, "device", "cuda") device = getattr(self.model.pretrained_model, "device", "cuda")
@ -153,7 +159,7 @@ class ChatModel:
truncation=True, truncation=True,
max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024), max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
return_tensors="pt", return_tensors="pt",
**kwargs add_special_tokens=True
).to(device) ).to(device)
input_ids: torch.Tensor = inputs["input_ids"] input_ids: torch.Tensor = inputs["input_ids"]
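A hedged sketch of how the extended chat interface might be exercised with the new `tools` argument; `ChatModel` is assumed to accept an argument dictionary, and the model path and tool definition are placeholders:

```python
import json

from llmtuner import ChatModel

# Placeholder arguments; adjust the model path and template to your setup.
chat_model = ChatModel(dict(model_name_or_path="path/to/model", template="default"))

# A tool specification in the structure consumed by the new ToolFormatter.
tools = json.dumps([{
    "name": "get_weather",
    "description": "Query current weather",
    "parameters": {
        "type": "object",
        "properties": {"city": {"type": "string", "description": "City name"}},
        "required": ["city"]
    }
}])

responses = chat_model.chat("What is the weather in Paris?", tools=tools)
print(responses[0])
```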

View File

@ -1,4 +1,6 @@
from llmtuner.data.loader import get_dataset from .loader import get_dataset
from llmtuner.data.preprocess import preprocess_dataset from .template import get_template_and_fix_tokenizer, templates
from llmtuner.data.template import get_template_and_fix_tokenizer from .utils import split_dataset, Role
from llmtuner.data.utils import split_dataset
__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset", "Role"]

View File

@ -0,0 +1,106 @@
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union
from .utils import Role
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from ..hparams import DataArguments
from .parser import DatasetAttr
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
if dataset_attr.history:
for old_prompt, old_response in examples[dataset_attr.history][i]:
prompt.append({"role": Role.USER, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT, "content": old_response})
instruction = examples[dataset_attr.prompt][i]
if dataset_attr.query and examples[dataset_attr.query][i]:
instruction += "\n" + examples[dataset_attr.query][i]
prompt.append({"role": Role.USER, "content": instruction})
if isinstance(examples[dataset_attr.response][i], list):
response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
else:
response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append("")
return outputs
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
tag_mapping = {
dataset_attr.user_tag: Role.USER,
dataset_attr.assistant_tag: Role.ASSISTANT,
dataset_attr.observation_tag: Role.OBSERVATION,
dataset_attr.function_tag: Role.FUNCTION
}
for i, messages in enumerate(examples[dataset_attr.messages]):
messages = messages[:len(messages) // 2 * 2] # should be multiples of 2
if len(messages) == 0:
continue
prompt = []
response = []
for turn_idx, message in enumerate(messages):
if turn_idx % 2 == 0:
accept_tags = [dataset_attr.user_tag, dataset_attr.observation_tag]
else:
accept_tags = [dataset_attr.assistant_tag, dataset_attr.function_tag]
if message[dataset_attr.role_tag] not in accept_tags:
raise ValueError("Invalid role tag in {}.".format(messages))
prompt.append({"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]})
last_message = prompt.pop(-1)
response.append(last_message)
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
return outputs
def align_dataset(
dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
prompt: [{"role": "user", "content": "..."}]
response: [{"role": "assistant", "content": "..."}]
system: "..."
tools: "..."
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
else:
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Converting format of dataset"
)
return dataset.map(
convert_func,
batched=True,
remove_columns=column_names,
**kwargs
)
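To make the target schema concrete, a single aligned example (matching the `align_dataset` docstring above) would look roughly like this; the dialogue content is invented:

```python
# Shape of one aligned example as produced by convert_alpaca / convert_sharegpt.
aligned_example = {
    "prompt": [{"role": Role.USER, "content": "Name the capital of France."}],
    "response": [{"role": Role.ASSISTANT, "content": "The capital of France is Paris."}],
    "system": "",
    "tools": ""
}
```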

View File

@ -0,0 +1,102 @@
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Union
JSON_FORMAT_PROMPT = (
""", in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
)
TOOL_SYSTEM_PROMPT = (
"You have access to the following tools:\n{tool_text}"
"Use the following format to answer the question:\n"
"```\n"
"Action: the action to take, should be one of [{tool_names}] if using a tool.\n"
"Action Input: the input to the action{format_prompt}.\n"
"```"
)
@dataclass
class StringFormatter:
container: List[Union[str, Dict[str, str]]]
def __call__(self, **kwargs) -> List[Union[str, Dict[str, str]]]:
elements = []
for elem in self.container:
if isinstance(elem, str):
for name, value in kwargs.items():
elem = elem.replace("{{" + name + "}}", value)
elements.append(elem)
elif isinstance(elem, (dict, set)):
elements.append(elem)
else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
return elements
@dataclass
class FunctionFormatter:
container: List[Union[str, Dict[str, str]]]
def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
try:
function = json.loads(content)
name = json.dumps(function["name"], ensure_ascii=False)
arguments = json.dumps(function["arguments"], ensure_ascii=False)
except json.JSONDecodeError:
name, arguments = "", ""
elements = []
for elem in self.container:
if isinstance(elem, str):
elem = elem.replace("{{name}}", name)
elem = elem.replace("{{arguments}}", arguments)
elements.append(elem)
elif isinstance(elem, (dict, set)):
elements.append(elem)
else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
return elements
@dataclass
class ToolFormatter:
type: Literal["default"]
def _default(self, tools: List[Dict[str, Any]]) -> str:
tool_text = ""
tool_names = []
for tool in tools:
param_text = ""
for name, param in tool["parameters"]["properties"].items():
required = ", required" if name in tool["parameters"].get("required", []) else ""
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
param_text += " - {name} ({type}{required}): {desc}{enum}\n".format(
name=name, type=param.get("type", ""), required=required, desc=param.get("description", ""), enum=enum
)
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
name=tool["name"], desc=tool.get("description", ""), args=param_text
)
tool_names.append(tool["name"])
return TOOL_SYSTEM_PROMPT.format(
tool_text=tool_text,
tool_names=", ".join(tool_names),
format_prompt=JSON_FORMAT_PROMPT
)
def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
try:
tools = json.loads(content)
if not len(tools):
return [""]
if self.type == "default":
return [self._default(tools)]
except json.JSONDecodeError:
return [""]

View File

@ -1,160 +1,114 @@
import os import os
from typing import TYPE_CHECKING, Any, Dict, List, Union from typing import TYPE_CHECKING, List, Literal, Union
from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
from llmtuner.data.utils import checksum from ..extras.constants import FILEEXT2TYPE
from llmtuner.extras.constants import FILEEXT2TYPE from ..extras.logging import get_logger
from llmtuner.extras.logging import get_logger from .utils import checksum
from .parser import get_dataset_list
from .aligner import align_dataset
from .template import get_template_and_fix_tokenizer
from .preprocess import get_preprocess_and_print_func
if TYPE_CHECKING: if TYPE_CHECKING:
from datasets import Dataset, IterableDataset from datasets import Dataset, IterableDataset
from llmtuner.hparams import ModelArguments, DataArguments from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from .parser import DatasetAttr
from ..hparams import ModelArguments, DataArguments
logger = get_logger(__name__) logger = get_logger(__name__)
def get_dataset( def load_single_dataset(
dataset_attr: "DatasetAttr",
model_args: "ModelArguments", model_args: "ModelArguments",
data_args: "DataArguments" data_args: "DataArguments",
) -> Union["Dataset", "IterableDataset"]: ):
max_samples = data_args.max_samples data_path, data_name, data_dir, data_files = None, None, None, None
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
data_path = dataset_attr.dataset_name
data_name = dataset_attr.subset
data_dir = dataset_attr.folder
if data_args.cache_path is not None: elif dataset_attr.load_from == "script":
if os.path.exists(data_args.cache_path): data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
logger.warning("Loading dataset from disk will ignore other data arguments.") data_name = dataset_attr.subset
dataset = load_from_disk(data_args.cache_path) data_dir = dataset_attr.folder
if data_args.streaming:
dataset = dataset.to_iterable_dataset()
return dataset
elif data_args.streaming:
raise ValueError("Turn off dataset streaming to save cache files.")
for dataset_attr in data_args.dataset_list: elif dataset_attr.load_from == "file":
logger.info("Loading dataset {}...".format(dataset_attr)) data_files = []
local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
data_path, data_name, data_dir, data_files = None, None, None, None if os.path.isdir(local_path): # is directory
if dataset_attr.load_from in ["hf_hub", "ms_hub"]: for file_name in os.listdir(local_path):
data_path = dataset_attr.dataset_name data_files.append(os.path.join(local_path, file_name))
data_name = dataset_attr.subset if data_path is None:
data_dir = dataset_attr.folder data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
elif dataset_attr.load_from == "script": elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) raise ValueError("File types should be identical.")
data_name = dataset_attr.subset elif os.path.isfile(local_path): # is file
elif dataset_attr.load_from == "file": data_files.append(local_path)
data_files = [] data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
if os.path.isdir(local_path): # is directory
for file_name in os.listdir(local_path):
data_files.append(os.path.join(local_path, file_name))
if data_path is None:
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
else:
assert data_path == FILEEXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
elif os.path.isfile(local_path): # is file
data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else:
raise ValueError("File not found.")
assert data_path, "File extension must be txt, csv, json or jsonl."
checksum(data_files, dataset_attr.dataset_sha1)
else: else:
raise NotImplementedError raise ValueError("File not found.")
if dataset_attr.load_from == "ms_hub": if data_path is None:
try: raise ValueError("File extension must be txt, csv, json or jsonl.")
from modelscope import MsDataset
from modelscope.utils.config_ds import MS_DATASETS_CACHE
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE checksum(data_files, dataset_attr.dataset_sha1)
dataset = MsDataset.load( else:
dataset_name=data_path, raise NotImplementedError
subset_name=data_name,
data_dir=data_dir, if dataset_attr.load_from == "ms_hub":
data_files=data_files, try:
split=data_args.split, from modelscope import MsDataset
cache_dir=cache_dir, from modelscope.utils.config_ds import MS_DATASETS_CACHE
token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")) cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
).to_hf_dataset() dataset = MsDataset.load(
except ImportError: dataset_name=data_path,
raise ImportError("Please install modelscope via `pip install modelscope -U`") subset_name=data_name,
else:
dataset = load_dataset(
path=data_path,
name=data_name,
data_dir=data_dir, data_dir=data_dir,
data_files=data_files, data_files=data_files,
split=data_args.split, split=data_args.split,
cache_dir=model_args.cache_dir, cache_dir=cache_dir,
token=model_args.hf_hub_token, token=model_args.ms_hub_token,
streaming=(data_args.streaming and (dataset_attr.load_from != "file")) use_streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
) ).to_hf_dataset()
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")
else:
dataset = load_dataset(
path=data_path,
name=data_name,
data_dir=data_dir,
data_files=data_files,
split=data_args.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
)
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if max_samples is not None: # truncate dataset if data_args.max_samples is not None: # truncate dataset
dataset = dataset.select(range(min(len(dataset), max_samples))) num_samples = min(data_args.max_samples, len(dataset))
dataset = dataset.select(range(num_samples))
def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: return align_dataset(dataset, dataset_attr, data_args)
# convert dataset from sharegpt format to alpaca format
outputs = {"prompt": [], "query": [], "response": [], "history": [], "system": []}
for i, msg_list in enumerate(examples[dataset_attr.messages]):
msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
if len(msg_list) == 0:
continue
msg_pairs = []
user_role, assistant_role = None, None
for idx in range(0, len(msg_list), 2):
if user_role is None and assistant_role is None:
user_role = msg_list[idx][dataset_attr.role]
assistant_role = msg_list[idx + 1][dataset_attr.role]
else:
if (
msg_list[idx][dataset_attr.role] != user_role
or msg_list[idx+1][dataset_attr.role] != assistant_role
):
raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))
if len(msg_pairs) != 0: def merge_dataset(
outputs["prompt"].append(msg_pairs[-1][0]) all_datasets: List[Union["Dataset", "IterableDataset"]],
outputs["query"].append("") data_args: "DataArguments",
outputs["response"].append(msg_pairs[-1][1]) training_args: "Seq2SeqTrainingArguments"
outputs["history"].append(msg_pairs[:-1] if len(msg_pairs) > 1 else None) ) -> Union["Dataset", "IterableDataset"]:
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "") if len(all_datasets) == 1:
return outputs
if dataset_attr.formatting == "sharegpt": # convert format
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Converting format of dataset"
)
dataset = dataset.map(
convert_format,
batched=True,
remove_columns=column_names,
**kwargs
)
else:
for column_name in ["prompt", "query", "response", "history", "system"]: # align dataset
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
all_datasets.append(dataset)
if len(data_args.dataset_list) == 1:
return all_datasets[0] return all_datasets[0]
elif data_args.mix_strategy == "concat": elif data_args.mix_strategy == "concat":
if data_args.streaming: if data_args.streaming:
@ -166,8 +120,72 @@ def get_dataset(
return interleave_datasets( return interleave_datasets(
datasets=all_datasets, datasets=all_datasets,
probabilities=data_args.interleave_probs, probabilities=data_args.interleave_probs,
seed=data_args.seed, seed=training_args.seed,
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted" stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
) )
else: else:
raise ValueError("Unknown mixing strategy.") raise ValueError("Unknown mixing strategy.")
def get_dataset(
model_args: "ModelArguments",
data_args: "DataArguments",
tokenizer: "PreTrainedTokenizer",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
# split: Optional[str] = "train", # TODO: add split
) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
# Load from cache
if data_args.cache_path is not None:
if os.path.exists(data_args.cache_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.cache_path)
if data_args.streaming:
dataset = dataset.to_iterable_dataset()
return dataset
if data_args.streaming:
raise ValueError("Turn off dataset streaming to save cache files.")
with training_args.main_process_first(desc="load dataset"):
all_datasets = []
for dataset_attr in get_dataset_list(data_args): # TODO: add split
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
dataset = merge_dataset(all_datasets, data_args, training_args)
with training_args.main_process_first(desc="pre-process dataset"):
preprocess_func, print_function = get_preprocess_and_print_func(
tokenizer, template, data_args, training_args, stage
)
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Running tokenizer on dataset"
)
dataset = dataset.map(
preprocess_func,
batched=True,
remove_columns=column_names,
**kwargs
)
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
if training_args.should_save:
dataset.save_to_disk(data_args.cache_path)
logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
if training_args.should_log:
try:
print_function(next(iter(dataset)))
except StopIteration:
raise RuntimeError("Empty dataset!")
return dataset
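Call sites change accordingly; a hedged sketch of invoking the reworked `get_dataset`, where the argument objects and tokenizer are assumed to come from the framework's usual argument-parsing and model-loading steps:

```python
# model_args, data_args, training_args and tokenizer are assumed to be prepared elsewhere.
dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="sft")
```

The template lookup and preprocessing that previously lived in `preprocess_dataset` now happen inside this single call.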

src/llmtuner/data/parser.py (new file, 101 lines)
View File

@ -0,0 +1,101 @@
import os
import json
from typing import TYPE_CHECKING, List, Literal, Optional
from dataclasses import dataclass
from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope
if TYPE_CHECKING:
from ..hparams import DataArguments
@dataclass
class DatasetAttr:
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None
subset: Optional[str] = None
folder: Optional[str] = None
ranking: Optional[bool] = False
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
system: Optional[str] = None
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
messages: Optional[str] = "conversations"
tools: Optional[str] = None
role_tag: Optional[str] = "from"
content_tag: Optional[str] = "value"
user_tag: Optional[str] = "human"
assistant_tag: Optional[str] = "gpt"
observation_tag: Optional[str] = "observation"
function_tag: Optional[str] = "function_call"
def __repr__(self) -> str:
return self.dataset_name
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
try:
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
dataset_info = json.load(f)
except Exception as err:
if data_args.dataset is not None:
raise ValueError("Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)))
dataset_info = None
if data_args.interleave_probs is not None:
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
dataset_list: List[DatasetAttr] = []
for name in dataset_names:
if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
has_hf_url = "hf_hub_url" in dataset_info[name]
has_ms_url = "ms_hub_url" in dataset_info[name]
if has_hf_url or has_ms_url:
if (use_modelscope() and has_ms_url) or (not has_hf_url):
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
else:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
else:
dataset_attr = DatasetAttr(
"file",
dataset_name=dataset_info[name]["file_name"],
dataset_sha1=dataset_info[name].get("file_sha1", None)
)
dataset_attr.subset = dataset_info[name].get("subset", None)
dataset_attr.folder = dataset_info[name].get("folder", None)
dataset_attr.ranking = dataset_info[name].get("ranking", False)
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
if "columns" in dataset_info[name]:
if dataset_attr.formatting == "alpaca":
column_names = ["prompt", "query", "response", "history"]
else:
column_names = ["messages", "tools"]
column_names += ["system"]
for column_name in column_names:
setattr(dataset_attr, column_name, dataset_info[name]["columns"].get(column_name, None))
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
for tag in ["role_tag", "content_tag", "user_tag", "assistant_tag", "observation_tag", "function_tag"]:
setattr(dataset_attr, tag, dataset_info[name]["tags"].get(tag, None))
dataset_list.append(dataset_attr)
return dataset_list
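Tying the parser back to `dataset_info.json`, an entry exercising the new `columns` and `tags` fields could look like the sketch below (the dataset name and file are placeholders); `get_dataset_list` maps it onto a `DatasetAttr` as shown above:

```python
# Illustrative dataset_info.json entry, written as a Python dict; names are placeholders.
dataset_info_entry = {
    "my_toolcall_data": {
        "file_name": "my_toolcall_data.json",
        "formatting": "sharegpt",
        "columns": {"messages": "conversations", "tools": "tools"},
        "tags": {
            "role_tag": "from",
            "content_tag": "value",
            "user_tag": "human",
            "assistant_tag": "gpt",
            "observation_tag": "observation",
            "function_tag": "function_call"
        }
    }
}
```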

View File

@ -1,272 +1,241 @@
import os from functools import partial
import tiktoken
from itertools import chain from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
from llmtuner.data.template import get_template_and_fix_tokenizer from ..extras.constants import IGNORE_INDEX
from llmtuner.extras.constants import IGNORE_INDEX from ..extras.logging import get_logger
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer from transformers.tokenization_utils import PreTrainedTokenizer
from llmtuner.hparams import DataArguments
from ..hparams import DataArguments
from .template import Template
logger = get_logger(__name__) logger = get_logger(__name__)
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]: def preprocess_pretrain_dataset(
for i in range(len(examples["prompt"])): examples: Dict[str, List[Any]],
query, response = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
history = examples["history"][i] if "history" in examples else None
system = examples["system"][i] if "system" in examples else None
yield query, response, history, system
def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, data_args.reserved_label_len)
max_source_len = data_args.cutoff_len - max_target_len
return max_source_len, max_target_len
def preprocess_dataset(
dataset: Union["Dataset", "IterableDataset"],
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...`
text_examples = [examples["prompt"][i][0]["content"] for i in range(len(examples["prompt"]))]
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
for i in range(len(tokenized_examples["input_ids"])):
tokenized_examples["input_ids"][i] += [tokenizer.eos_token_id]
tokenized_examples["attention_mask"][i] += [1]
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
result = {
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
return result
def preprocess_supervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", ) -> Dict[str, List[List[int]]]:
stage: Literal["pt", "sft", "rm", "ppo"] # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
) -> Union["Dataset", "IterableDataset"]: # for multiturn examples, we only mask the prompt part in each prompt-response pair.
template = get_template_and_fix_tokenizer(data_args.template, tokenizer) model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if data_args.cache_path is not None and os.path.exists(data_args.cache_path): for i in range(len(examples["prompt"])):
return dataset # already preprocessed if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
continue
if data_args.train_on_prompt and template.efficient_eos: messages = examples["prompt"][i] + examples["response"][i]
raise ValueError("Current template does not support `train_on_prompt`.")
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...`
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=True)
if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer
add_eos_token_flag = getattr(tokenizer, "add_eos_token")
setattr(tokenizer, "add_eos_token", True)
tokenized_examples = tokenizer(examples["prompt"], **kwargs)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
result = {
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
# make sure the saved tokenizer is the same as the original one
if hasattr(tokenizer, "add_eos_token"):
setattr(tokenizer, "add_eos_token", add_eos_token_flag)
return result
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
continue
input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, query, response, history, system
)):
source_len, target_len = len(source_ids), len(target_ids)
max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
if source_len > max_source_len:
source_ids = source_ids[:max_source_len]
if target_len > max_target_len:
target_ids = target_ids[:max_target_len]
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
if len(input_ids) > data_args.cutoff_len:
input_ids = input_ids[:data_args.cutoff_len]
labels = labels[:data_args.cutoff_len]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
input_ids, labels = [], [] input_ids, labels = [], []
for query, response, history, system in construct_example(examples): for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""): tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
continue )):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn( input_ids += source_ids + target_ids
tokenizer, query, response, history, system labels += source_mask + target_ids
)):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos: if template.efficient_eos:
input_ids += [tokenizer.eos_token_id] input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id] labels += [tokenizer.eos_token_id]
total_length = len(input_ids) model_inputs["input_ids"].append(input_ids)
block_size = data_args.cutoff_len model_inputs["attention_mask"].append([1] * len(input_ids))
# we drop the small remainder, and if the total_length < block_size, we exclude this batch model_inputs["labels"].append(labels)
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
for i in range(0, total_length, block_size):
model_inputs["input_ids"].append(input_ids[i: i + block_size])
model_inputs["attention_mask"].append([1] * block_size)
model_inputs["labels"].append(labels[i: i + block_size])
return model_inputs return model_inputs
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for query, response, history, system in construct_example(examples): def preprocess_packed_supervised_dataset(
if not (isinstance(query, str) and query != ""): examples: Dict[str, List[Any]],
continue tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
input_ids, labels = [], []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
continue
input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system) messages = examples["prompt"][i] + examples["response"][i]
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, messages, examples["system"][i], examples["tools"][i]
)):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
if template.efficient_eos: input_ids += source_ids + target_ids
labels += [tokenizer.eos_token_id] labels += source_mask + target_ids
if len(input_ids) > data_args.cutoff_len: if template.efficient_eos:
input_ids = input_ids[:data_args.cutoff_len] input_ids += [tokenizer.eos_token_id]
if len(labels) > data_args.cutoff_len: labels += [tokenizer.eos_token_id]
labels = labels[:data_args.cutoff_len]
model_inputs["input_ids"].append(input_ids) total_length = len(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids)) block_size = data_args.cutoff_len
model_inputs["labels"].append(labels) # we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
for i in range(0, total_length, block_size):
model_inputs["input_ids"].append(input_ids[i: i + block_size])
model_inputs["attention_mask"].append([1] * block_size)
model_inputs["labels"].append(labels[i: i + block_size])
return model_inputs return model_inputs
def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
continue
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system) def preprocess_unsupervised_dataset(
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system) examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if template.efficient_eos: for i in range(len(examples["prompt"])):
chosen_ids += [tokenizer.eos_token_id] if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
rejected_ids += [tokenizer.eos_token_id] continue
source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids)) messages = examples["prompt"][i] + examples["response"][i]
max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args) input_ids, labels = template.encode_oneturn(
if source_len > max_source_len: tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
prompt_ids = prompt_ids[:max_source_len]
if target_len > max_target_len:
chosen_ids = chosen_ids[:max_target_len]
rejected_ids = rejected_ids[:max_target_len]
model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)
return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
))
def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
print("prompt_ids:\n{}".format(example["prompt_ids"]))
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
print("chosen_ids:\n{}".format(example["chosen_ids"]))
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
print("rejected_ids:\n{}".format(example["rejected_ids"]))
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
if stage == "pt":
preprocess_func = preprocess_pretrain_dataset
print_function = print_unsupervised_dataset_example
elif stage == "sft" and not training_args.predict_with_generate:
preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
print_function = print_supervised_dataset_example
elif stage == "rm":
preprocess_func = preprocess_pairwise_dataset
print_function = print_pairwise_dataset_example
else:
preprocess_func = preprocess_unsupervised_dataset
print_function = print_unsupervised_dataset_example
with training_args.main_process_first(desc="dataset map pre-processing"):
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Running tokenizer on dataset"
)
dataset = dataset.map(
preprocess_func,
batched=True,
remove_columns=column_names,
**kwargs
) )
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path): if template.efficient_eos:
if training_args.should_save: labels += [tokenizer.eos_token_id]
dataset.save_to_disk(data_args.cache_path)
logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
if training_args.should_log: model_inputs["input_ids"].append(input_ids)
try: model_inputs["attention_mask"].append([1] * len(input_ids))
print_function(next(iter(dataset))) model_inputs["labels"].append(labels)
except StopIteration:
raise RuntimeError("Empty dataset!")
return dataset return model_inputs
def preprocess_pairwise_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) < 2:
continue
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
prompt_ids, chosen_ids = template.encode_oneturn(
tokenizer, chosen_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
)
_, rejected_ids = template.encode_oneturn(
tokenizer, rejected_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
)
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)
return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
))
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("prompt_ids:\n{}".format(example["prompt_ids"]))
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
print("chosen_ids:\n{}".format(example["chosen_ids"]))
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
print("rejected_ids:\n{}".format(example["rejected_ids"]))
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
def get_preprocess_and_print_func(
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
) -> Tuple[Callable, Callable]:
if stage == "pt":
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not training_args.predict_with_generate:
if data_args.sft_packing:
preprocess_func = partial(
preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)
else:
preprocess_func = partial(
preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
elif stage == "rm":
preprocess_func = partial(preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args)
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
else:
preprocess_func = partial(preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
return preprocess_func, print_function
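As orientation for how this new entry point is consumed, a minimal usage sketch follows (the surrounding get_dataset plumbing and the variables tokenizer, template, data_args, training_args and dataset are assumed, not shown in this hunk):

# Hypothetical wiring; mirrors the dataset.map(...) call that previously lived inside preprocess_dataset.
preprocess_func, print_function = get_preprocess_and_print_func(
    tokenizer, template, data_args, training_args, stage="sft"
)
column_names = list(next(iter(dataset)).keys())
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names)
print_function(next(iter(dataset)))  # decode and print one sample as a sanity check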

File diff suppressed because it is too large

View File

@ -1,7 +1,8 @@
import hashlib import hashlib
from typing import TYPE_CHECKING, Dict, List, Optional, Union from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from llmtuner.extras.logging import get_logger from ..extras.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from datasets import Dataset, IterableDataset from datasets import Dataset, IterableDataset
@ -12,6 +13,14 @@ if TYPE_CHECKING:
logger = get_logger(__name__) logger = get_logger(__name__)
@unique
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
OBSERVATION = "observation"
FUNCTION = "function"
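To make the new roles concrete, a hedged sketch of how a tool-calling turn could be represented with this enum (contents are invented for illustration; the exact payload format of function and observation turns is defined by the datasets and templates elsewhere in this PR):

# Illustrative only; Role is the enum defined above.
example_messages = [
    {"role": Role.USER, "content": "What is the weather like in Paris?"},
    {"role": Role.FUNCTION, "content": "{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Paris\"}}"},
    {"role": Role.OBSERVATION, "content": "{\"temperature\": \"12C\", \"condition\": \"cloudy\"}"},
    {"role": Role.ASSISTANT, "content": "It is currently about 12C and cloudy in Paris."},
]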
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None: def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
if file_sha1 is None: if file_sha1 is None:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.") logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
@ -27,6 +36,13 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0])) logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, data_args.reserved_label_len)
max_source_len = data_args.cutoff_len - max_target_len
return max_source_len, max_target_len
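A quick worked example of the proportional budget computed by infer_max_len above (the concrete numbers are illustrative):

# Suppose cutoff_len=1024 and reserved_label_len=1.
source_len, target_len = 900, 300
max_target_len = max(int(1024 * target_len / (source_len + target_len)), 1)  # 256
max_source_len = 1024 - max_target_len  # 768, so prompt and response are truncated in proportion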
def split_dataset( def split_dataset(
dataset: Union["Dataset", "IterableDataset"], dataset: Union["Dataset", "IterableDataset"],
data_args: "DataArguments", data_args: "DataArguments",

View File

@ -1 +1,4 @@
from llmtuner.eval.evaluator import Evaluator from .evaluator import Evaluator
__all__ = ["Evaluator"]

View File

@ -3,7 +3,6 @@
import os import os
import json import json
import torch import torch
import tiktoken
import numpy as np import numpy as np
from tqdm import tqdm, trange from tqdm import tqdm, trange
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@ -11,10 +10,11 @@ from typing import Any, Dict, List, Optional
from datasets import load_dataset from datasets import load_dataset
from transformers.utils import cached_file from transformers.utils import cached_file
from llmtuner.data.template import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from llmtuner.eval.template import get_eval_template from .template import get_eval_template
from llmtuner.extras.constants import CHOICES, SUBJECTS from ..extras.constants import CHOICES, SUBJECTS
from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer from ..hparams import get_eval_args
from ..model import dispatch_model, load_model_and_tokenizer
class Evaluator: class Evaluator:
@ -26,15 +26,9 @@ class Evaluator:
self.model = dispatch_model(self.model) self.model = dispatch_model(self.model)
self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer) self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
self.eval_template = get_eval_template(self.eval_args.lang) self.eval_template = get_eval_template(self.eval_args.lang)
self.choice_inputs = self._encode_choices()
def _encode_choices(self) -> List[int]:
if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=False)
return [self.tokenizer.encode(self.eval_template.prefix + ch, **kwargs)[-1] for ch in CHOICES]
self.choice_inputs = [self.tokenizer.encode(
self.eval_template.prefix + ch, add_special_tokens=False
)[-1] for ch in CHOICES]
@torch.inference_mode() @torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]: def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
@ -71,17 +65,17 @@ class Evaluator:
inputs, outputs, labels = [], [], [] inputs, outputs, labels = [], [], []
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False): for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"])))) support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
query, resp, history = self.eval_template.format_example( messages = self.eval_template.format_example(
target_data=dataset[self.data_args.split][i], target_data=dataset[self.data_args.split][i],
support_set=support_set, support_set=support_set,
subject_name=categorys[subject]["name"], subject_name=categorys[subject]["name"]
use_history=self.template.use_history
) )
input_ids, _ = self.template.encode_oneturn( input_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, query=query, resp=resp, history=history tokenizer=self.tokenizer, messages=messages
) )
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)}) inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
labels.append(resp) labels.append(messages[-1]["content"])
for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False): for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
batch_input = self.tokenizer.pad( batch_input = self.tokenizer.pad(

View File

@ -1,7 +1,8 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple from typing import TYPE_CHECKING, Dict, List, Tuple
from llmtuner.extras.constants import CHOICES from ..extras.constants import CHOICES
from ..data import Role
if TYPE_CHECKING: if TYPE_CHECKING:
from datasets import Dataset from datasets import Dataset
@ -28,23 +29,26 @@ class EvalTemplate:
support_set: "Dataset", support_set: "Dataset",
subject_name: str, subject_name: str,
use_history: bool use_history: bool
) -> Tuple[str, str, List[Tuple[str, str]]]:
query, resp = self.parse_example(target_data)
history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
if len(history):
temp = history.pop(0)
history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
else:
query = self.system.format(subject=subject_name) + query
if not use_history:
query = "\n\n".join(["".join(item) for item in history] + [query])
history = []
return query.strip(), resp, history
) -> List[Dict[str, str]]:
messages = []
for k in range(len(support_set)):
prompt, response = self.parse_example(support_set[k])
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
prompt, response = self.parse_example(target_data)
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
if not use_history:
messages = [{"role": Role.USER, "content": "\n\n".join([message["content"] for message in messages[:-1]])}]
return messages
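For orientation, with one support example the rewritten format_example now returns a flat message list along these lines (question text abbreviated, contents illustrative; the system prompt is folded into the first user turn):

# Hypothetical return value for one-shot evaluation:
[
    {"role": Role.USER, "content": "The following are multiple choice questions about <subject>.\n\n<support question> Answer:"},
    {"role": Role.ASSISTANT, "content": "A"},
    {"role": Role.USER, "content": "<target question> Answer:"},
    {"role": Role.ASSISTANT, "content": "B"},
]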
eval_templates: Dict[str, EvalTemplate] = {} eval_templates: Dict[str, "EvalTemplate"] = {}
def register_eval_template( def register_eval_template(
@ -62,7 +66,7 @@ def register_eval_template(
) )
def get_eval_template(name: str) -> EvalTemplate: def get_eval_template(name: str) -> "EvalTemplate":
eval_template = eval_templates.get(name, None) eval_template = eval_templates.get(name, None)
assert eval_template is not None, "Template {} does not exist.".format(name) assert eval_template is not None, "Template {} does not exist.".format(name)
return eval_template return eval_template

View File

@ -6,9 +6,9 @@ from datetime import timedelta
from transformers import TrainerCallback from transformers import TrainerCallback
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
from llmtuner.extras.constants import LOG_FILE_NAME from .constants import LOG_FILE_NAME
from llmtuner.extras.logging import get_logger from .logging import get_logger
from llmtuner.extras.misc import fix_valuehead_checkpoint from .misc import fix_valuehead_checkpoint
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@ -5,6 +5,8 @@ from typing import Dict, Optional
CHOICES = ["A", "B", "C", "D"] CHOICES = ["A", "B", "C", "D"]
DATA_CONFIG = "dataset_info.json"
DEFAULT_MODULE = defaultdict(str) DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str) DEFAULT_TEMPLATE = defaultdict(str)
@ -339,6 +341,30 @@ register_model_group(
) )
register_model_group(
models={
"InternLM2-7B": {
DownloadSource.DEFAULT: "internlm/internlm2-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b"
},
"InternLM2-20B": {
DownloadSource.DEFAULT: "internlm/internlm2-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b"
},
"InternLM2-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2-chat-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b"
},
"InternLM2-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b"
}
},
module="wqkv",
template="intern2"
)
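With this registration in place, InternLM2 can be addressed by name from the training entry point; a hedged sketch, assuming run_exp is re-exported at the package top level (dataset choice and output path are illustrative):

from llmtuner import run_exp  # assumption: top-level re-export

run_exp(dict(
    stage="sft", do_train=True,
    model_name_or_path="internlm/internlm2-7b",
    template="intern2", finetuning_type="lora", lora_target="wqkv",
    dataset="alpaca_gpt4_en", output_dir="internlm2_lora",
))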
register_model_group( register_model_group(
models={ models={
"LingoWhale-8B": { "LingoWhale-8B": {

View File

@ -13,8 +13,8 @@ from transformers.utils import (
) )
from peft import PeftModel from peft import PeftModel
from llmtuner.extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME from .constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
from llmtuner.extras.logging import get_logger from .logging import get_logger
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()

View File

@ -10,7 +10,7 @@ try:
except ImportError: except ImportError:
print("Please upgrade `transformers`.") print("Please upgrade `transformers`.")
from llmtuner.extras.packages import is_flash_attn2_available from ..packages import is_flash_attn2_available
if is_flash_attn2_available(): if is_flash_attn2_available():

View File

@ -4,8 +4,8 @@ import json
from typing import List, Optional from typing import List, Optional
from transformers.trainer import TRAINER_STATE_NAME from transformers.trainer import TRAINER_STATE_NAME
from llmtuner.extras.logging import get_logger from .logging import get_logger
from llmtuner.extras.packages import is_matplotlib_available from .packages import is_matplotlib_available
if is_matplotlib_available(): if is_matplotlib_available():
import matplotlib.pyplot as plt import matplotlib.pyplot as plt

View File

@ -3,3 +3,16 @@ from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments from .generating_args import GeneratingArguments
from .model_args import ModelArguments from .model_args import ModelArguments
from .parser import get_train_args, get_infer_args, get_eval_args
__all__ = [
"DataArguments",
"EvaluationArguments",
"FinetuningArguments",
"GeneratingArguments",
"ModelArguments",
"get_train_args",
"get_infer_args",
"get_eval_args"
]

View File

@ -1,40 +1,7 @@
import os from typing import Literal, Optional
import json
from typing import List, Literal, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
DATA_CONFIG = "dataset_info.json"
def use_modelscope() -> bool:
return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
@dataclass
class DatasetAttr:
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None
subset: Optional[str] = None
folder: Optional[str] = None
ranking: Optional[bool] = False
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
messages: Optional[str] = "conversations"
role: Optional[str] = "from"
content: Optional[str] = "value"
system: Optional[str] = None
def __repr__(self) -> str:
return self.dataset_name
@dataclass @dataclass
class DataArguments: class DataArguments:
r""" r"""
@ -126,64 +93,3 @@ class DataArguments:
if self.streaming and self.max_samples is not None: if self.streaming and self.max_samples is not None:
raise ValueError("`max_samples` is incompatible with `streaming`.") raise ValueError("`max_samples` is incompatible with `streaming`.")
def init_for_training(self, seed: int): # support mixing multiple datasets
self.seed = seed
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
try:
with open(os.path.join(self.dataset_dir, DATA_CONFIG), "r") as f:
dataset_info = json.load(f)
except Exception as err:
if self.dataset is not None:
raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err)))
dataset_info = None
if self.interleave_probs is not None:
self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
self.dataset_list: List[DatasetAttr] = []
for name in dataset_names:
if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
has_hf_url = "hf_hub_url" in dataset_info[name]
has_ms_url = "ms_hub_url" in dataset_info[name]
if has_hf_url or has_ms_url:
if (use_modelscope() and has_ms_url) or (not has_hf_url):
dataset_attr = DatasetAttr(
"ms_hub",
dataset_name=dataset_info[name]["ms_hub_url"]
)
else:
dataset_attr = DatasetAttr(
"hf_hub",
dataset_name=dataset_info[name]["hf_hub_url"]
)
elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr(
"script",
dataset_name=dataset_info[name]["script_url"]
)
else:
dataset_attr = DatasetAttr(
"file",
dataset_name=dataset_info[name]["file_name"],
dataset_sha1=dataset_info[name].get("file_sha1", None)
)
if "columns" in dataset_info[name]:
dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
dataset_attr.role = dataset_info[name]["columns"].get("role", None)
dataset_attr.content = dataset_info[name]["columns"].get("content", None)
dataset_attr.system = dataset_info[name]["columns"].get("system", None)
dataset_attr.subset = dataset_info[name].get("subset", None)
dataset_attr.folder = dataset_info[name].get("folder", None)
dataset_attr.ranking = dataset_info[name].get("ranking", False)
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
self.dataset_list.append(dataset_attr)

View File

@ -43,13 +43,5 @@ class EvaluationArguments:
) )
def __post_init__(self): def __post_init__(self):
task_available = []
for folder in os.listdir(self.task_dir):
if os.path.isdir(os.path.join(self.task_dir, folder)):
task_available.append(folder)
if self.task not in task_available:
raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir))
if self.save_dir is not None and os.path.exists(self.save_dir): if self.save_dir is not None and os.path.exists(self.save_dir):
raise ValueError("`save_dir` already exists, use another one.") raise ValueError("`save_dir` already exists, use another one.")

View File

@ -8,14 +8,12 @@ from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser, Seq2SeqTrainingArguments from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint from transformers.trainer_utils import get_last_checkpoint
from llmtuner.extras.logging import get_logger from ..extras.logging import get_logger
from llmtuner.hparams import ( from .data_args import DataArguments
ModelArguments, from .evaluation_args import EvaluationArguments
DataArguments, from .finetuning_args import FinetuningArguments
EvaluationArguments, from .generating_args import GeneratingArguments
FinetuningArguments, from .model_args import ModelArguments
GeneratingArguments
)
logger = get_logger(__name__) logger = get_logger(__name__)
@ -107,8 +105,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
_set_transformers_logging() _set_transformers_logging()
# Check arguments # Check arguments
data_args.init_for_training(training_args.seed)
if finetuning_args.stage != "pt" and data_args.template is None: if finetuning_args.stage != "pt" and data_args.template is None:
raise ValueError("Please specify which `template` to use.") raise ValueError("Please specify which `template` to use.")

View File

@ -1,5 +1,5 @@
# Level: loader > adapter > parser, utils from .loader import load_model_and_tokenizer
from .utils import dispatch_model, get_modelcard_args, load_valuehead_params
from llmtuner.model.loader import load_model_and_tokenizer
from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args __all__ = ["load_model_and_tokenizer", "dispatch_model", "get_modelcard_args", "load_valuehead_params"]
from llmtuner.model.utils import dispatch_model, get_modelcard_args, load_valuehead_params

View File

@ -3,12 +3,12 @@ from typing import TYPE_CHECKING
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from peft import PeftModel, TaskType, LoraConfig, get_peft_model from peft import PeftModel, TaskType, LoraConfig, get_peft_model
from llmtuner.extras.logging import get_logger from ..extras.logging import get_logger
from llmtuner.model.utils import find_all_linear_modules from .utils import find_all_linear_modules
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from llmtuner.hparams import ModelArguments, FinetuningArguments from ..hparams import ModelArguments, FinetuningArguments
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@ -4,15 +4,15 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.logging import get_logger from ..extras.logging import get_logger
from llmtuner.extras.misc import count_parameters, get_current_device, try_download_model_from_ms from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from llmtuner.model.adapter import init_adapter from .adapter import init_adapter
from llmtuner.model.patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model from .patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
from llmtuner.model.utils import load_valuehead_params, register_autoclass from .utils import load_valuehead_params, register_autoclass
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer from transformers import PreTrainedModel, PreTrainedTokenizer
from llmtuner.hparams import ModelArguments, FinetuningArguments from ..hparams import ModelArguments, FinetuningArguments
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@ -10,15 +10,15 @@ from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTra
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from llmtuner.extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
from llmtuner.extras.logging import get_logger from ..extras.logging import get_logger
from llmtuner.extras.misc import get_current_device, infer_optim_dtype from ..extras.misc import get_current_device, infer_optim_dtype
from llmtuner.extras.packages import is_flash_attn2_available from ..extras.packages import is_flash_attn2_available
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer from transformers import PretrainedConfig, PreTrainedTokenizer
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import ModelArguments from ..hparams import ModelArguments
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@ -4,13 +4,13 @@ from typing import TYPE_CHECKING, Any, Dict, List
from transformers import PreTrainedModel from transformers import PreTrainedModel
from transformers.utils import cached_file from transformers.utils import cached_file
from llmtuner.extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME from ..extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
from llmtuner.extras.logging import get_logger from ..extras.logging import get_logger
from llmtuner.extras.misc import get_current_device from ..extras.misc import get_current_device
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer from transformers import PretrainedConfig, PreTrainedTokenizer
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments from ..hparams import ModelArguments, DataArguments, FinetuningArguments
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@ -1 +1,4 @@
from llmtuner.train.tuner import export_model, run_exp from .tuner import export_model, run_exp
__all__ = ["export_model", "run_exp"]

View File

@ -1 +1,4 @@
from llmtuner.train.dpo.workflow import run_dpo from .workflow import run_dpo
__all__ = ["run_dpo"]

View File

@ -5,7 +5,7 @@ from transformers import BatchEncoding, Trainer
from trl import DPOTrainer from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model from trl.trainer.utils import disable_dropout_in_model
from llmtuner.extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedModel from transformers import PreTrainedModel

View File

@ -3,18 +3,18 @@
from typing import TYPE_CHECKING, Optional, List from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset from ...data import get_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss from ...extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments from ...hparams import ModelArguments
from llmtuner.model import load_model_and_tokenizer from ...model import load_model_and_tokenizer
from llmtuner.train.dpo.collator import DPODataCollatorWithPadding from ...train.dpo.collator import DPODataCollatorWithPadding
from llmtuner.train.dpo.trainer import CustomDPOTrainer from ...train.dpo.trainer import CustomDPOTrainer
from llmtuner.train.utils import create_modelcard_and_push, create_ref_model from ...train.utils import create_modelcard_and_push, create_ref_model
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainerCallback from transformers import TrainerCallback
from llmtuner.hparams import DataArguments, FinetuningArguments from ...hparams import DataArguments, FinetuningArguments
def run_dpo( def run_dpo(
@ -24,9 +24,8 @@ def run_dpo(
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None callbacks: Optional[List["TrainerCallback"]] = None
): ):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm") dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="rm")
data_collator = DPODataCollatorWithPadding( data_collator = DPODataCollatorWithPadding(
tokenizer=tokenizer, tokenizer=tokenizer,
pad_to_multiple_of=8, pad_to_multiple_of=8,

View File

@ -1 +1,4 @@
from llmtuner.train.ppo.workflow import run_ppo from .workflow import run_ppo
__all__ = ["run_ppo"]

View File

@ -13,15 +13,15 @@ from transformers.trainer_pt_utils import remove_dummy_checkpoint
from trl import PPOTrainer from trl import PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits from trl.core import PPODecorators, logprobs_from_logits
from llmtuner.extras.callbacks import LogCallback, FixValueHeadModelCallback from ...extras.callbacks import LogCallback, FixValueHeadModelCallback
from llmtuner.extras.logging import get_logger from ...extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model from .utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback from transformers import Seq2SeqTrainingArguments, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments from ...hparams import ModelArguments, FinetuningArguments, GeneratingArguments
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@ -2,7 +2,7 @@ import json
import torch import torch
from typing import TYPE_CHECKING, Dict, List, Literal, Optional from typing import TYPE_CHECKING, Dict, List, Literal, Optional
from llmtuner.extras.packages import is_requests_available from ...extras.packages import is_requests_available
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedModel from transformers import PreTrainedModel

View File

@ -7,17 +7,17 @@ from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorWithPadding from transformers import DataCollatorWithPadding
from transformers.optimization import get_scheduler from transformers.optimization import get_scheduler
from llmtuner.data import get_dataset, preprocess_dataset from ...data import get_dataset
from llmtuner.extras.callbacks import FixValueHeadModelCallback from ...extras.callbacks import FixValueHeadModelCallback
from llmtuner.extras.misc import fix_valuehead_checkpoint from ...extras.misc import fix_valuehead_checkpoint
from llmtuner.extras.ploting import plot_loss from ...extras.ploting import plot_loss
from llmtuner.model import load_model_and_tokenizer from ...model import load_model_and_tokenizer
from llmtuner.train.utils import create_ref_model, create_reward_model from ...train.utils import create_ref_model, create_reward_model
from llmtuner.train.ppo.trainer import CustomPPOTrainer from ...train.ppo.trainer import CustomPPOTrainer
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments from ...hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
def run_ppo( def run_ppo(
@ -28,9 +28,8 @@ def run_ppo(
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None callbacks: Optional[List["TrainerCallback"]] = None
): ):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo") dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="ppo")
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = DataCollatorWithPadding(tokenizer=tokenizer) data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

View File

@ -1 +1,4 @@
from llmtuner.train.pt.workflow import run_pt from .workflow import run_pt
__all__ = ["run_pt"]

View File

@ -4,14 +4,14 @@ import math
from typing import TYPE_CHECKING, Optional, List from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForLanguageModeling, Trainer from transformers import DataCollatorForLanguageModeling, Trainer
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset from ...data import get_dataset, split_dataset
from llmtuner.extras.ploting import plot_loss from ...extras.ploting import plot_loss
from llmtuner.model import load_model_and_tokenizer from ...model import load_model_and_tokenizer
from llmtuner.train.utils import create_modelcard_and_push from ...train.utils import create_modelcard_and_push
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments from ...hparams import ModelArguments, DataArguments, FinetuningArguments
def run_pt( def run_pt(
@ -21,9 +21,8 @@ def run_pt(
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None callbacks: Optional[List["TrainerCallback"]] = None
): ):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt") dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="pt")
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Initialize our Trainer # Initialize our Trainer

View File

@ -1 +1,4 @@
from llmtuner.train.rm.workflow import run_rm from .workflow import run_rm
__all__ = ["run_rm"]

View File

@ -4,7 +4,7 @@ import torch
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from transformers import Trainer from transformers import Trainer
from llmtuner.extras.logging import get_logger from ...extras.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.trainer import PredictionOutput from transformers.trainer import PredictionOutput

View File

@ -3,19 +3,19 @@
from typing import TYPE_CHECKING, Optional, List from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset from ...data import get_dataset, split_dataset
from llmtuner.extras.callbacks import FixValueHeadModelCallback from ...extras.callbacks import FixValueHeadModelCallback
from llmtuner.extras.misc import fix_valuehead_checkpoint from ...extras.misc import fix_valuehead_checkpoint
from llmtuner.extras.ploting import plot_loss from ...extras.ploting import plot_loss
from llmtuner.model import load_model_and_tokenizer from ...model import load_model_and_tokenizer
from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding from ...train.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.train.rm.metric import compute_accuracy from ...train.rm.metric import compute_accuracy
from llmtuner.train.rm.trainer import PairwiseTrainer from ...train.rm.trainer import PairwiseTrainer
from llmtuner.train.utils import create_modelcard_and_push from ...train.utils import create_modelcard_and_push
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainerCallback from transformers import TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments from ...hparams import ModelArguments, DataArguments, FinetuningArguments
def run_rm( def run_rm(
@ -25,9 +25,8 @@ def run_rm(
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None callbacks: Optional[List["TrainerCallback"]] = None
): ):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm") dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="rm")
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
# Update arguments # Update arguments

View File

@ -1 +1,4 @@
from llmtuner.train.sft.workflow import run_sft from .workflow import run_sft
__all__ = ["run_sft"]

View File

@ -2,8 +2,8 @@ import numpy as np
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
from llmtuner.extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from llmtuner.extras.packages import ( from ...extras.packages import (
is_jieba_available, is_nltk_available, is_rouge_available is_jieba_available, is_nltk_available, is_rouge_available
) )

View File

@ -6,8 +6,8 @@ import torch.nn as nn
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from transformers import Seq2SeqTrainer from transformers import Seq2SeqTrainer
from llmtuner.extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger from ...extras.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.trainer import PredictionOutput from transformers.trainer import PredictionOutput

View File

@ -3,18 +3,19 @@
from typing import TYPE_CHECKING, Optional, List from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset from ...data import get_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor from ...extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss from ...extras.ploting import plot_loss
from llmtuner.model import load_model_and_tokenizer from ...model import load_model_and_tokenizer
from llmtuner.train.sft.metric import ComputeMetrics from ...train.sft.metric import ComputeMetrics
from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer from ...train.sft.trainer import CustomSeq2SeqTrainer
from llmtuner.train.utils import create_modelcard_and_push from ...train.utils import create_modelcard_and_push
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainerCallback from transformers import TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments from ...hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
def run_sft( def run_sft(
@ -25,9 +26,8 @@ def run_sft(
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None callbacks: Optional[List["TrainerCallback"]] = None
): ):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft") dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="sft")
if training_args.predict_with_generate: if training_args.predict_with_generate:
tokenizer.padding_side = "left" # use left-padding in generation tokenizer.padding_side = "left" # use left-padding in generation

View File

@ -2,14 +2,15 @@ import torch
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
from transformers import PreTrainedModel from transformers import PreTrainedModel
from llmtuner.extras.callbacks import LogCallback from ..extras.callbacks import LogCallback
from llmtuner.extras.logging import get_logger from ..extras.logging import get_logger
from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer from ..hparams import get_train_args, get_infer_args
from llmtuner.train.pt import run_pt from ..model import load_model_and_tokenizer
from llmtuner.train.sft import run_sft from .pt import run_pt
from llmtuner.train.rm import run_rm from .sft import run_sft
from llmtuner.train.ppo import run_ppo from .rm import run_rm
from llmtuner.train.dpo import run_dpo from .ppo import run_ppo
from .dpo import run_dpo
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainerCallback from transformers import TrainerCallback

View File

@ -1,15 +1,15 @@
import torch import torch
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, Optional, Union
from llmtuner.extras.logging import get_logger from ..extras.logging import get_logger
from llmtuner.hparams import ModelArguments, FinetuningArguments from ..hparams import ModelArguments, FinetuningArguments
from llmtuner.model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params from ..model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, Trainer from transformers import Seq2SeqTrainingArguments, Trainer
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import DataArguments from ..hparams import DataArguments
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@ -1 +1,4 @@
from llmtuner.webui.interface import create_ui, create_web_demo from .interface import create_ui, create_web_demo
__all__ = ["create_ui", "create_web_demo"]

View File

@ -2,14 +2,14 @@ import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from llmtuner.chat import ChatModel from ..chat import ChatModel
from llmtuner.extras.misc import torch_gc from ..extras.misc import torch_gc
from llmtuner.hparams import GeneratingArguments from ..hparams import GeneratingArguments
from llmtuner.webui.common import get_save_dir from .common import get_save_dir
from llmtuner.webui.locales import ALERTS from .locales import ALERTS
if TYPE_CHECKING: if TYPE_CHECKING:
from llmtuner.webui.manager import Manager from .manager import Manager
class WebChatModel(ChatModel): class WebChatModel(ChatModel):
@ -105,6 +105,7 @@ class WebChatModel(ChatModel):
query: str, query: str,
history: List[Tuple[str, str]], history: List[Tuple[str, str]],
system: str, system: str,
tools: str,
max_new_tokens: int, max_new_tokens: int,
top_p: float, top_p: float,
temperature: float temperature: float
@ -112,7 +113,7 @@ class WebChatModel(ChatModel):
chatbot.append([query, ""]) chatbot.append([query, ""])
response = "" response = ""
for new_text in self.stream_chat( for new_text in self.stream_chat(
query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature query, history, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
): ):
response += new_text response += new_text
new_history = history + [(query, response)] new_history = history + [(query, response)]
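The tools string is now threaded through to the chat model; a hedged standalone sketch of the call that predict() ends up making (chat_model, query, history, system and tools are assumed to be in scope, generation kwargs illustrative):

# Mirrors the streaming loop above, outside the Gradio callback.
response = ""
for new_text in chat_model.stream_chat(query, history, system, tools, max_new_tokens=512, top_p=0.7, temperature=0.95):
    response += new_text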

View File

@ -5,7 +5,8 @@ from collections import defaultdict
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
from llmtuner.extras.constants import ( from ..extras.constants import (
DATA_CONFIG,
DEFAULT_MODULE, DEFAULT_MODULE,
DEFAULT_TEMPLATE, DEFAULT_TEMPLATE,
PEFT_METHODS, PEFT_METHODS,
@ -13,8 +14,7 @@ from llmtuner.extras.constants import (
TRAINING_STAGES, TRAINING_STAGES,
DownloadSource DownloadSource
) )
from llmtuner.extras.misc import use_modelscope from ..extras.misc import use_modelscope
from llmtuner.hparams.data_args import DATA_CONFIG
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME} ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}

View File

@ -1,6 +1,11 @@
from llmtuner.webui.components.top import create_top from .top import create_top
from llmtuner.webui.components.train import create_train_tab from .train import create_train_tab
from llmtuner.webui.components.eval import create_eval_tab from .eval import create_eval_tab
from llmtuner.webui.components.infer import create_infer_tab from .infer import create_infer_tab
from llmtuner.webui.components.export import create_export_tab from .export import create_export_tab
from llmtuner.webui.components.chatbot import create_chat_box from .chatbot import create_chat_box
__all__ = [
"create_top", "create_train_tab", "create_eval_tab", "create_infer_tab", "create_export_tab", "create_chat_box"
]

View File

@ -1,10 +1,14 @@
import gradio as gr import gradio as gr
from typing import TYPE_CHECKING, Dict, Optional, Tuple from typing import TYPE_CHECKING, Dict, Optional, Tuple
from ..utils import check_json_schema
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.blocks import Block from gradio.blocks import Block
from gradio.components import Component from gradio.components import Component
from llmtuner.webui.engine import Engine
from ..engine import Engine
def create_chat_box( def create_chat_box(
@ -17,6 +21,7 @@ def create_chat_box(
with gr.Row(): with gr.Row():
with gr.Column(scale=4): with gr.Column(scale=4):
system = gr.Textbox(show_label=False) system = gr.Textbox(show_label=False)
tools = gr.Textbox(show_label=False, lines=2)
query = gr.Textbox(show_label=False, lines=8) query = gr.Textbox(show_label=False, lines=8)
submit_btn = gr.Button(variant="primary") submit_btn = gr.Button(variant="primary")
@ -27,9 +32,11 @@ def create_chat_box(
top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01) top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01) temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
tools.input(check_json_schema, [tools])
submit_btn.click( submit_btn.click(
engine.chatter.predict, engine.chatter.predict,
[chatbot, query, history, system, max_new_tokens, top_p, temperature], [chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
[chatbot, history], [chatbot, history],
show_progress=True show_progress=True
).then( ).then(
@ -40,6 +47,7 @@ def create_chat_box(
return chat_box, chatbot, history, dict( return chat_box, chatbot, history, dict(
system=system, system=system,
tools=tools,
query=query, query=query,
submit_btn=submit_btn, submit_btn=submit_btn,
clear_btn=clear_btn, clear_btn=clear_btn,

View File

@ -3,7 +3,7 @@ import json
import gradio as gr import gradio as gr
from typing import TYPE_CHECKING, Any, Dict, Tuple from typing import TYPE_CHECKING, Any, Dict, Tuple
from llmtuner.webui.common import DATA_CONFIG from ...extras.constants import DATA_CONFIG
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component

View File

@ -1,12 +1,13 @@
import gradio as gr import gradio as gr
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR from ..common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box from .data import create_preview_box
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component
from llmtuner.webui.engine import Engine
from ..engine import Engine
def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]: def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:

View File

@ -1,13 +1,14 @@
import gradio as gr import gradio as gr
from typing import TYPE_CHECKING, Dict, Generator, List from typing import TYPE_CHECKING, Dict, Generator, List
from llmtuner.train import export_model from ...train import export_model
from llmtuner.webui.common import get_save_dir from ..common import get_save_dir
from llmtuner.webui.locales import ALERTS from ..locales import ALERTS
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component
from llmtuner.webui.engine import Engine
from ..engine import Engine
GPTQ_BITS = ["8", "4", "3", "2"] GPTQ_BITS = ["8", "4", "3", "2"]

View File

@ -1,11 +1,12 @@
import gradio as gr import gradio as gr
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict
from llmtuner.webui.components.chatbot import create_chat_box from .chatbot import create_chat_box
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component
from llmtuner.webui.engine import Engine
from ..engine import Engine
def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]: def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:

View File

@ -1,10 +1,10 @@
import gradio as gr import gradio as gr
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict
from llmtuner.data.template import templates from ...data import templates
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS from ...extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.webui.common import get_model_path, get_template, list_adapters, save_config from ..common import get_model_path, get_template, list_adapters, save_config
from llmtuner.webui.utils import can_quantize from ..utils import can_quantize
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component

View File

@ -2,14 +2,15 @@ import gradio as gr
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict
from transformers.trainer_utils import SchedulerType from transformers.trainer_utils import SchedulerType
from llmtuner.extras.constants import TRAINING_STAGES from ...extras.constants import TRAINING_STAGES
from llmtuner.webui.common import list_adapters, list_dataset, DEFAULT_DATA_DIR from ..common import list_adapters, list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box from ..components.data import create_preview_box
from llmtuner.webui.utils import gen_plot from ..utils import gen_plot
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component
from llmtuner.webui.engine import Engine
from ..engine import Engine
def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:

View File

@ -2,12 +2,12 @@ import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import Any, Dict, Generator, Optional from typing import Any, Dict, Generator, Optional
from llmtuner.webui.chatter import WebChatModel from .chatter import WebChatModel
from llmtuner.webui.common import get_model_path, list_dataset, load_config from .common import get_model_path, list_dataset, load_config
from llmtuner.webui.locales import LOCALES from .locales import LOCALES
from llmtuner.webui.manager import Manager from .manager import Manager
from llmtuner.webui.runner import Runner from .runner import Runner
from llmtuner.webui.utils import get_time from .utils import get_time
class Engine: class Engine:

View File

@ -2,7 +2,7 @@ import gradio as gr
from typing import Optional from typing import Optional
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from llmtuner.webui.components import ( from .components import (
create_top, create_top,
create_train_tab, create_train_tab,
create_eval_tab, create_eval_tab,
@ -10,9 +10,9 @@ from llmtuner.webui.components import (
create_export_tab, create_export_tab,
create_chat_box create_chat_box
) )
from llmtuner.webui.common import save_config from .common import save_config
from llmtuner.webui.css import CSS from .css import CSS
from llmtuner.webui.engine import Engine from .engine import Engine
require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"") require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")

View File

@ -521,6 +521,14 @@ LOCALES = {
"placeholder": "系统提示词(非必填)" "placeholder": "系统提示词(非必填)"
} }
}, },
"tools": {
"en": {
"placeholder": "Tools (optional)"
},
"zh": {
"placeholder": "工具列表(非必填)"
}
},
"query": { "query": {
"en": { "en": {
"placeholder": "Input..." "placeholder": "Input..."

View File

@ -9,17 +9,17 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
import transformers import transformers
from transformers.trainer import TRAINING_ARGS_NAME from transformers.trainer import TRAINING_ARGS_NAME
from llmtuner.extras.callbacks import LogCallback from ..extras.callbacks import LogCallback
from llmtuner.extras.constants import TRAINING_STAGES from ..extras.constants import TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler from ..extras.logging import LoggerHandler
from llmtuner.extras.misc import get_device_count, torch_gc from ..extras.misc import get_device_count, torch_gc
from llmtuner.train import run_exp from ..train import run_exp
from llmtuner.webui.common import get_module, get_save_dir, load_config from .common import get_module, get_save_dir, load_config
from llmtuner.webui.locales import ALERTS from .locales import ALERTS
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar from .utils import gen_cmd, get_eval_results, update_process_bar
if TYPE_CHECKING: if TYPE_CHECKING:
from llmtuner.webui.manager import Manager from .manager import Manager
class Runner: class Runner:

View File

@ -4,12 +4,12 @@ import gradio as gr
from typing import TYPE_CHECKING, Any, Dict from typing import TYPE_CHECKING, Any, Dict
from datetime import datetime from datetime import datetime
from llmtuner.extras.packages import is_matplotlib_available from ..extras.packages import is_matplotlib_available
from llmtuner.extras.ploting import smooth from ..extras.ploting import smooth
from llmtuner.webui.common import get_save_dir from .common import get_save_dir
if TYPE_CHECKING: if TYPE_CHECKING:
from llmtuner.extras.callbacks import LogCallback from ..extras.callbacks import LogCallback
if is_matplotlib_available(): if is_matplotlib_available():
import matplotlib.figure import matplotlib.figure
@ -41,6 +41,13 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
return gr.update(interactive=True) return gr.update(interactive=True)
def check_json_schema(text: str) -> None:
try:
json.loads(text)
except json.JSONDecodeError:
gr.Warning("Invalid JSON schema")
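For reference, a sketch of the kind of value the new tools textbox is meant to hold; check_json_schema only enforces JSON validity, so the concrete tool schema below is an assumption:

# Hypothetical tools value; in the UI, check_json_schema runs on every edit of the textbox.
tools_text = '[{"name": "get_weather", "description": "Query the weather", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}}]'
# A malformed value such as '[{"name": }' would surface the "Invalid JSON schema" warning instead.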
def gen_cmd(args: Dict[str, Any]) -> str: def gen_cmd(args: Dict[str, Any]) -> str:
args.pop("disable_tqdm", None) args.pop("disable_tqdm", None)
args["plot_loss"] = args.get("do_train", None) args["plot_loss"] = args.get("do_train", None)

View File

@ -1,6 +1,7 @@
# coding=utf-8 # coding=utf-8
# Converts the InternLM2 model into the same format as LLaMA2. # Converts the InternLM2 model into the same format as LLaMA2.
# Usage: python llamafy_internlm2.py --input_dir input --output_dir output --shard_size 10GB # Usage: python llamafy_internlm2.py --input_dir input --output_dir output --shard_size 10GB
# Warning: We have found that the converted model cannot infer correctly. It will be fixed later.
import os import os
import fire import fire
@ -43,19 +44,18 @@ def save_weight(
llama2_state_dict[key.replace("output", "lm_head")] = value llama2_state_dict[key.replace("output", "lm_head")] = value
elif "tok_embeddings" in key: elif "tok_embeddings" in key:
llama2_state_dict[key.replace("tok_embeddings", "embed_tokens")] = value llama2_state_dict[key.replace("tok_embeddings", "embed_tokens")] = value
elif "attention_norm" in key:
llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value
elif "wqkv" in key: elif "wqkv" in key:
proj_size = value.size(0)
num_q_heads = internlm2_config_dict["num_attention_heads"] num_q_heads = internlm2_config_dict["num_attention_heads"]
num_kv_heads = internlm2_config_dict["num_key_value_heads"] num_kv_heads = internlm2_config_dict["num_key_value_heads"]
q_size = proj_size // (num_q_heads + 2 * num_kv_heads) * num_q_heads q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads
kv_size = proj_size // (num_q_heads + 2 * num_kv_heads) * num_kv_heads kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...] llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...]
llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[q_size:q_size+kv_size, ...] llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[q_size:q_size+kv_size, ...]
llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size+kv_size:, ...] llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size+kv_size:, ...]
elif "wo" in key: elif "wo" in key:
llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value
elif "attention_norm" in key:
llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value
elif "ffn_norm" in key: elif "ffn_norm" in key:
llama2_state_dict[key.replace("ffn_norm", "post_attention_layernorm")] = value llama2_state_dict[key.replace("ffn_norm", "post_attention_layernorm")] = value
elif "w1" in key: elif "w1" in key: