diff --git a/README.md b/README.md
index 84ca47ee..1ec9eb50 100644
--- a/README.md
+++ b/README.md
@@ -55,14 +55,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog
+[24/01/18] We supported **agent tuning** for most models, equipping models with tool-using abilities by fine-tuning with `--dataset glaive_toolcall`.
+
[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `--use_unsloth` argument to activate unsloth patch. It achieves 1.7x speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
-[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage.
-
Full Changelog
+[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage.
+
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`.
[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention.
@@ -95,14 +97,13 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| Model | Model size | Default module | Template |
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
-| [Baichuan](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan |
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
-| [InternLM](https://huggingface.co/internlm) | 7B/20B | q_proj,v_proj | intern |
+| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
@@ -183,6 +184,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
+- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
diff --git a/README_zh.md b/README_zh.md
index 498bdb8d..a2ce1e14 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -55,6 +55,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
## 更新日志
+[24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `--dataset glaive_toolcall` 即可使模型获得工具调用能力。
+
[23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `--use_unsloth` 参数启用 unsloth 优化。该方法可提供 1.7 倍的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
[23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。
@@ -95,14 +97,13 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
| 模型名 | 模型大小 | 默认模块 | Template |
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
-| [Baichuan](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan |
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
-| [InternLM](https://huggingface.co/internlm) | 7B/20B | q_proj,v_proj | intern |
+| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
@@ -183,6 +184,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
+- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
diff --git a/data/dataset_info.json b/data/dataset_info.json
index bc031d76..ee35ac0e 100644
--- a/data/dataset_info.json
+++ b/data/dataset_info.json
@@ -165,9 +165,13 @@
"hf_hub_url": "HuggingFaceH4/ultrachat_200k",
"ms_hub_url": "AI-ModelScope/ultrachat_200k",
"columns": {
- "messages": "messages",
- "role": "role",
- "content": "content"
+ "messages": "messages"
+ },
+ "tags": {
+ "role_tag": "role",
+ "content_tag": "content",
+ "user_tag": "human",
+ "assistant_tag": "assistant"
},
"formatting": "sharegpt"
},
@@ -180,9 +184,13 @@
"hf_hub_url": "lmsys/lmsys-chat-1m",
"ms_hub_url": "AI-ModelScope/lmsys-chat-1m",
"columns": {
- "messages": "conversation",
- "role": "role",
- "content": "content"
+ "messages": "conversation"
+ },
+ "tags": {
+ "role_tag": "role",
+ "content_tag": "content",
+ "user_tag": "human",
+ "assistant_tag": "assistant"
},
"formatting": "sharegpt"
},
@@ -190,6 +198,14 @@
"hf_hub_url": "WizardLM/WizardLM_evol_instruct_V2_196k",
"formatting": "sharegpt"
},
+ "glaive_toolcall": {
+ "file_name": "glaive_toolcall_10k.json",
+ "formatting": "sharegpt",
+ "columns": {
+ "messages": "conversations",
+ "tools": "tools"
+ }
+ },
"hh_rlhf_en": {
"script_url": "hh_rlhf_en",
"columns": {
diff --git a/data/glaive_toolcall_10k.json.REMOVED.git-id b/data/glaive_toolcall_10k.json.REMOVED.git-id
new file mode 100644
index 00000000..64693b28
--- /dev/null
+++ b/data/glaive_toolcall_10k.json.REMOVED.git-id
@@ -0,0 +1 @@
+4748dff00d1dc42768a5b6cc772143c313017812
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index ce3c92a3..0e9090e4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,7 +9,6 @@ scipy
einops
sentencepiece
protobuf
-tiktoken
jieba
rouge-chinese
nltk
diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py
index 932dd56b..bdbd5af2 100644
--- a/src/llmtuner/__init__.py
+++ b/src/llmtuner/__init__.py
@@ -8,3 +8,12 @@ from llmtuner.webui import create_ui, create_web_demo
__version__ = "0.4.0"
+__all__ = [
+ "create_app",
+ "ChatModel",
+ "Evaluator",
+ "export_model",
+ "run_exp",
+ "create_ui",
+ "create_web_demo"
+]
diff --git a/src/llmtuner/api/__init__.py b/src/llmtuner/api/__init__.py
index b3ce183a..d7059fbd 100644
--- a/src/llmtuner/api/__init__.py
+++ b/src/llmtuner/api/__init__.py
@@ -1 +1,4 @@
-from llmtuner.api.app import create_app
+from .app import create_app
+
+
+__all__ = ["create_app"]
diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py
index f130eab6..f8115227 100644
--- a/src/llmtuner/api/app.py
+++ b/src/llmtuner/api/app.py
@@ -5,7 +5,7 @@ from typing import List, Tuple
from pydantic import BaseModel
from contextlib import asynccontextmanager
-from llmtuner.api.protocol import (
+from .protocol import (
Role,
Finish,
ModelCard,
@@ -21,9 +21,9 @@ from llmtuner.api.protocol import (
ScoreEvaluationRequest,
ScoreEvaluationResponse
)
-from llmtuner.chat import ChatModel
-from llmtuner.extras.misc import torch_gc
-from llmtuner.extras.packages import (
+from ..chat import ChatModel
+from ..extras.misc import torch_gc
+from ..extras.packages import (
is_fastapi_availble, is_starlette_available, is_uvicorn_available
)
diff --git a/src/llmtuner/api/protocol.py b/src/llmtuner/api/protocol.py
index a5b5c81d..42569f84 100644
--- a/src/llmtuner/api/protocol.py
+++ b/src/llmtuner/api/protocol.py
@@ -1,15 +1,17 @@
import time
-from enum import Enum
+from enum import Enum, unique
from pydantic import BaseModel, Field
from typing import List, Optional
+@unique
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
+@unique
class Finish(str, Enum):
STOP = "stop"
LENGTH = "length"
diff --git a/src/llmtuner/chat/__init__.py b/src/llmtuner/chat/__init__.py
index f86efe96..702d0ac7 100644
--- a/src/llmtuner/chat/__init__.py
+++ b/src/llmtuner/chat/__init__.py
@@ -1 +1,4 @@
-from llmtuner.chat.chat_model import ChatModel
+from .chat_model import ChatModel
+
+
+__all__ = ["ChatModel"]
diff --git a/src/llmtuner/chat/chat_model.py b/src/llmtuner/chat/chat_model.py
index 0c2f9c92..88846dee 100644
--- a/src/llmtuner/chat/chat_model.py
+++ b/src/llmtuner/chat/chat_model.py
@@ -1,13 +1,13 @@
import torch
-import tiktoken
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
from threading import Thread
from transformers import GenerationConfig, TextIteratorStreamer
-from llmtuner.data.template import get_template_and_fix_tokenizer
-from llmtuner.extras.misc import get_logits_processor
-from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer
+from ..data import get_template_and_fix_tokenizer, Role
+from ..extras.misc import get_logits_processor
+from ..model import dispatch_model, load_model_and_tokenizer
+from ..hparams import get_infer_args
@dataclass
@@ -36,10 +36,19 @@ class ChatModel:
query: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None,
+ tools: Optional[str] = None,
**input_kwargs
) -> Tuple[Dict[str, Any], int]:
+ messages = []
+ if history is not None:
+ for old_prompt, old_response in history:
+ messages.append({"role": Role.USER, "content": old_prompt})
+ messages.append({"role": Role.ASSISTANT, "content": old_response})
+
+ messages.append({"role": Role.USER, "content": query})
+ messages.append({"role": Role.ASSISTANT, "content": ""})
prompt, _ = self.template.encode_oneturn(
- tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
+ tokenizer=self.tokenizer, messages=messages, system=system, tools=tools
)
prompt_length = len(prompt)
input_ids = torch.tensor([prompt], device=self.model.device)
@@ -90,6 +99,7 @@ class ChatModel:
query: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None,
+ tools: Optional[str] = None,
**input_kwargs
) -> List[Response]:
r"""
@@ -97,7 +107,7 @@ class ChatModel:
Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
"""
- gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
+ gen_kwargs, prompt_length = self._process_args(query, history, system, tools, **input_kwargs)
generate_output = self.model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
response = self.tokenizer.batch_decode(
@@ -122,9 +132,10 @@ class ChatModel:
query: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None,
+ tools: Optional[str] = None,
**input_kwargs
) -> Generator[str, None, None]:
- gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs)
+ gen_kwargs, _ = self._process_args(query, history, system, tools, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
@@ -139,11 +150,6 @@ class ChatModel:
batch_input: List[str],
**input_kwargs
) -> List[float]:
- if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
- kwargs = dict(allowed_special="all")
- else:
- kwargs = dict(add_special_tokens=True)
-
max_length = input_kwargs.pop("max_length", None)
device = getattr(self.model.pretrained_model, "device", "cuda")
@@ -153,7 +159,7 @@ class ChatModel:
truncation=True,
max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
return_tensors="pt",
- **kwargs
+ add_special_tokens=True
).to(device)
input_ids: torch.Tensor = inputs["input_ids"]
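A minimal usage sketch of the extended chat interface follows. The model path and template are placeholders, and constructing `ChatModel` from a plain argument dict assumes its existing constructor behaviour; only the new `tools` keyword comes from this patch.

```python
import json
from llmtuner import ChatModel

# Sketch only: placeholder model path/template; the argument-dict constructor
# is assumed, the `tools` keyword argument is the new addition.
chat_model = ChatModel(dict(
    model_name_or_path="path/to/model",  # placeholder
    template="default",                  # placeholder
))

tools = json.dumps([{  # serialized tool spec forwarded to the template
    "name": "get_weather",
    "description": "Query the current weather for a city",
    "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
}])

responses = chat_model.chat("What is the weather in Berlin?", tools=tools)
print(responses[0])
```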
diff --git a/src/llmtuner/data/__init__.py b/src/llmtuner/data/__init__.py
index 35f7caa3..85be70b7 100644
--- a/src/llmtuner/data/__init__.py
+++ b/src/llmtuner/data/__init__.py
@@ -1,4 +1,6 @@
-from llmtuner.data.loader import get_dataset
-from llmtuner.data.preprocess import preprocess_dataset
-from llmtuner.data.template import get_template_and_fix_tokenizer
-from llmtuner.data.utils import split_dataset
+from .loader import get_dataset
+from .template import get_template_and_fix_tokenizer, templates
+from .utils import split_dataset, Role
+
+
+__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset", "Role"]
diff --git a/src/llmtuner/data/aligner.py b/src/llmtuner/data/aligner.py
new file mode 100644
index 00000000..f724c8a0
--- /dev/null
+++ b/src/llmtuner/data/aligner.py
@@ -0,0 +1,106 @@
+from functools import partial
+from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+from .utils import Role
+
+
+if TYPE_CHECKING:
+ from datasets import Dataset, IterableDataset
+
+ from ..hparams import DataArguments
+ from .parser import DatasetAttr
+
+
+def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
+ outputs = {"prompt": [], "response": [], "system": [], "tools": []}
+ for i in range(len(examples[dataset_attr.prompt])):
+ prompt = []
+ if dataset_attr.history:
+ for old_prompt, old_response in examples[dataset_attr.history][i]:
+ prompt.append({"role": Role.USER, "content": old_prompt})
+ prompt.append({"role": Role.ASSISTANT, "content": old_response})
+
+ instruction = examples[dataset_attr.prompt][i]
+ if dataset_attr.query and examples[dataset_attr.query][i]:
+ instruction += "\n" + examples[dataset_attr.query][i]
+ prompt.append({"role": Role.USER, "content": instruction})
+
+ if isinstance(examples[dataset_attr.response][i], list):
+ response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
+ else:
+ response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
+
+ outputs["prompt"].append(prompt)
+ outputs["response"].append(response)
+ outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
+ outputs["tools"].append("")
+
+ return outputs
+
+
+def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
+ outputs = {"prompt": [], "response": [], "system": [], "tools": []}
+ tag_mapping = {
+ dataset_attr.user_tag: Role.USER,
+ dataset_attr.assistant_tag: Role.ASSISTANT,
+ dataset_attr.observation_tag: Role.OBSERVATION,
+ dataset_attr.function_tag: Role.FUNCTION
+ }
+ for i, messages in enumerate(examples[dataset_attr.messages]):
+ messages = messages[:len(messages) // 2 * 2] # should be multiples of 2
+ if len(messages) == 0:
+ continue
+
+ prompt = []
+ response = []
+ for turn_idx, message in enumerate(messages):
+ if turn_idx % 2 == 0:
+ accept_tags = [dataset_attr.user_tag, dataset_attr.observation_tag]
+ else:
+ accept_tags = [dataset_attr.assistant_tag, dataset_attr.function_tag]
+
+ if message[dataset_attr.role_tag] not in accept_tags:
+ raise ValueError("Invalid role tag in {}.".format(messages))
+
+ prompt.append({"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]})
+
+ last_message = prompt.pop(-1)
+ response.append(last_message)
+ outputs["prompt"].append(prompt)
+ outputs["response"].append(response)
+ outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
+ outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
+
+ return outputs
+
+
+def align_dataset(
+ dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
+) -> Union["Dataset", "IterableDataset"]:
+ r"""
+ Aligned dataset:
+ prompt: [{"role": "user", "content": "..."}]
+ response: [{"role": "assistant", "content": "..."}]
+ system: "..."
+ tools: "..."
+ """
+ if dataset_attr.formatting == "alpaca":
+ convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
+ else:
+ convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
+
+ column_names = list(next(iter(dataset)).keys())
+ kwargs = {}
+ if not data_args.streaming:
+ kwargs = dict(
+ num_proc=data_args.preprocessing_num_workers,
+ load_from_cache_file=(not data_args.overwrite_cache),
+ desc="Converting format of dataset"
+ )
+
+ return dataset.map(
+ convert_func,
+ batched=True,
+ remove_columns=column_names,
+ **kwargs
+ )
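As a quick, illustrative check of the aligner, the snippet below runs `convert_sharegpt` on a two-turn toy record using the `DatasetAttr` defaults; the resulting shape matches the docstring of `align_dataset`.

```python
from llmtuner.data.aligner import convert_sharegpt
from llmtuner.data.parser import DatasetAttr

# Toy record using the default sharegpt tags ("from"/"value", "human"/"gpt").
attr = DatasetAttr("file", dataset_name="demo")
examples = {
    "conversations": [[
        {"from": "human", "value": "Name a prime number."},
        {"from": "gpt", "value": "7 is a prime number."},
    ]]
}

aligned = convert_sharegpt(examples, attr)
# prompt holds every turn except the last, response holds the final reply.
print(aligned["prompt"][0])    # [{"role": Role.USER, "content": "Name a prime number."}]
print(aligned["response"][0])  # [{"role": Role.ASSISTANT, "content": "7 is a prime number."}]
```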
diff --git a/src/llmtuner/data/formatter.py b/src/llmtuner/data/formatter.py
new file mode 100644
index 00000000..ce7b2819
--- /dev/null
+++ b/src/llmtuner/data/formatter.py
@@ -0,0 +1,102 @@
+import json
+from dataclasses import dataclass
+from typing import Any, Dict, List, Literal, Union
+
+
+JSON_FORMAT_PROMPT = (
+ """, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
+)
+
+
+TOOL_SYSTEM_PROMPT = (
+ "You have access to the following tools:\n{tool_text}"
+ "Use the following format to answer the question:\n"
+ "```\n"
+ "Action: the action to take, should be one of [{tool_names}] if using a tool.\n"
+ "Action Input: the input to the action{format_prompt}.\n"
+ "```"
+)
+
+
+@dataclass
+class StringFormatter:
+ container: List[Union[str, Dict[str, str]]]
+
+ def __call__(self, **kwargs) -> List[Union[str, Dict[str, str]]]:
+ elements = []
+ for elem in self.container:
+ if isinstance(elem, str):
+ for name, value in kwargs.items():
+ elem = elem.replace("{{" + name + "}}", value)
+ elements.append(elem)
+ elif isinstance(elem, (dict, set)):
+ elements.append(elem)
+ else:
+ raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
+
+ return elements
+
+
+@dataclass
+class FunctionFormatter:
+ container: List[Union[str, Dict[str, str]]]
+
+ def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
+ try:
+ function = json.loads(content)
+            name = function["name"]
+ arguments = json.dumps(function["arguments"], ensure_ascii=False)
+ except json.JSONDecodeError:
+ name, arguments = "", ""
+
+ elements = []
+ for elem in self.container:
+ if isinstance(elem, str):
+ elem = elem.replace("{{name}}", name)
+ elem = elem.replace("{{arguments}}", arguments)
+ elements.append(elem)
+ elif isinstance(elem, (dict, set)):
+ elements.append(elem)
+ else:
+ raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
+
+ return elements
+
+
+@dataclass
+class ToolFormatter:
+ type: Literal["default"]
+
+ def _default(self, tools: List[Dict[str, Any]]) -> str:
+ tool_text = ""
+ tool_names = []
+ for tool in tools:
+ param_text = ""
+ for name, param in tool["parameters"]["properties"].items():
+ required = ", required" if name in tool["parameters"].get("required", []) else ""
+ enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
+ param_text += " - {name} ({type}{required}): {desc}{enum}\n".format(
+ name=name, type=param.get("type", ""), required=required, desc=param.get("description", ""), enum=enum
+ )
+
+ tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
+ name=tool["name"], desc=tool.get("description", ""), args=param_text
+ )
+ tool_names.append(tool["name"])
+
+ return TOOL_SYSTEM_PROMPT.format(
+ tool_text=tool_text,
+ tool_names=", ".join(tool_names),
+ format_prompt=JSON_FORMAT_PROMPT
+ )
+
+ def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
+ try:
+ tools = json.loads(content)
+ if not len(tools):
+ return [""]
+
+ if self.type == "default":
+ return [self._default(tools)]
+ except json.JSONDecodeError:
+ return [""]
diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py
index f9019c8b..87f42558 100644
--- a/src/llmtuner/data/loader.py
+++ b/src/llmtuner/data/loader.py
@@ -1,160 +1,114 @@
import os
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, List, Literal, Union
from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
-from llmtuner.data.utils import checksum
-from llmtuner.extras.constants import FILEEXT2TYPE
-from llmtuner.extras.logging import get_logger
+from ..extras.constants import FILEEXT2TYPE
+from ..extras.logging import get_logger
+from .utils import checksum
+from .parser import get_dataset_list
+from .aligner import align_dataset
+from .template import get_template_and_fix_tokenizer
+from .preprocess import get_preprocess_and_print_func
+
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
- from llmtuner.hparams import ModelArguments, DataArguments
+ from transformers import Seq2SeqTrainingArguments
+ from transformers.tokenization_utils import PreTrainedTokenizer
+
+ from .parser import DatasetAttr
+ from ..hparams import ModelArguments, DataArguments
logger = get_logger(__name__)
-def get_dataset(
+def load_single_dataset(
+ dataset_attr: "DatasetAttr",
model_args: "ModelArguments",
- data_args: "DataArguments"
-) -> Union["Dataset", "IterableDataset"]:
- max_samples = data_args.max_samples
- all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
+ data_args: "DataArguments",
+) -> Union["Dataset", "IterableDataset"]:
+ data_path, data_name, data_dir, data_files = None, None, None, None
+ if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
+ data_path = dataset_attr.dataset_name
+ data_name = dataset_attr.subset
+ data_dir = dataset_attr.folder
- if data_args.cache_path is not None:
- if os.path.exists(data_args.cache_path):
- logger.warning("Loading dataset from disk will ignore other data arguments.")
- dataset = load_from_disk(data_args.cache_path)
- if data_args.streaming:
- dataset = dataset.to_iterable_dataset()
- return dataset
- elif data_args.streaming:
- raise ValueError("Turn off dataset streaming to save cache files.")
+ elif dataset_attr.load_from == "script":
+ data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+ data_name = dataset_attr.subset
+ data_dir = dataset_attr.folder
- for dataset_attr in data_args.dataset_list:
- logger.info("Loading dataset {}...".format(dataset_attr))
-
- data_path, data_name, data_dir, data_files = None, None, None, None
- if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
- data_path = dataset_attr.dataset_name
- data_name = dataset_attr.subset
- data_dir = dataset_attr.folder
- elif dataset_attr.load_from == "script":
- data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
- data_name = dataset_attr.subset
- elif dataset_attr.load_from == "file":
- data_files = []
- local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
- if os.path.isdir(local_path): # is directory
- for file_name in os.listdir(local_path):
- data_files.append(os.path.join(local_path, file_name))
- if data_path is None:
- data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
- else:
- assert data_path == FILEEXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
- elif os.path.isfile(local_path): # is file
- data_files.append(local_path)
- data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
- else:
- raise ValueError("File not found.")
-
- assert data_path, "File extension must be txt, csv, json or jsonl."
- checksum(data_files, dataset_attr.dataset_sha1)
+ elif dataset_attr.load_from == "file":
+ data_files = []
+ local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+ if os.path.isdir(local_path): # is directory
+ for file_name in os.listdir(local_path):
+ data_files.append(os.path.join(local_path, file_name))
+ if data_path is None:
+ data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
+ elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
+ raise ValueError("File types should be identical.")
+ elif os.path.isfile(local_path): # is file
+ data_files.append(local_path)
+ data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else:
- raise NotImplementedError
+ raise ValueError("File not found.")
- if dataset_attr.load_from == "ms_hub":
- try:
- from modelscope import MsDataset
- from modelscope.utils.config_ds import MS_DATASETS_CACHE
+ if data_path is None:
+ raise ValueError("File extension must be txt, csv, json or jsonl.")
- cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
- dataset = MsDataset.load(
- dataset_name=data_path,
- subset_name=data_name,
- data_dir=data_dir,
- data_files=data_files,
- split=data_args.split,
- cache_dir=cache_dir,
- token=model_args.ms_hub_token,
- use_streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
- ).to_hf_dataset()
- except ImportError:
- raise ImportError("Please install modelscope via `pip install modelscope -U`")
- else:
- dataset = load_dataset(
- path=data_path,
- name=data_name,
+ checksum(data_files, dataset_attr.dataset_sha1)
+ else:
+ raise NotImplementedError
+
+ if dataset_attr.load_from == "ms_hub":
+ try:
+ from modelscope import MsDataset
+ from modelscope.utils.config_ds import MS_DATASETS_CACHE
+
+ cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
+ dataset = MsDataset.load(
+ dataset_name=data_path,
+ subset_name=data_name,
data_dir=data_dir,
data_files=data_files,
split=data_args.split,
- cache_dir=model_args.cache_dir,
- token=model_args.hf_hub_token,
- streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
- )
+ cache_dir=cache_dir,
+ token=model_args.ms_hub_token,
+ use_streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
+ ).to_hf_dataset()
+ except ImportError:
+ raise ImportError("Please install modelscope via `pip install modelscope -U`")
+ else:
+ dataset = load_dataset(
+ path=data_path,
+ name=data_name,
+ data_dir=data_dir,
+ data_files=data_files,
+ split=data_args.split,
+ cache_dir=model_args.cache_dir,
+ token=model_args.hf_hub_token,
+ streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
+ )
- if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
- dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
+ if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
+ dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
- if max_samples is not None: # truncate dataset
- dataset = dataset.select(range(min(len(dataset), max_samples)))
+ if data_args.max_samples is not None: # truncate dataset
+ num_samples = min(data_args.max_samples, len(dataset))
+ dataset = dataset.select(range(num_samples))
- def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
- # convert dataset from sharegpt format to alpaca format
- outputs = {"prompt": [], "query": [], "response": [], "history": [], "system": []}
- for i, msg_list in enumerate(examples[dataset_attr.messages]):
- msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
- if len(msg_list) == 0:
- continue
+ return align_dataset(dataset, dataset_attr, data_args)
- msg_pairs = []
- user_role, assistant_role = None, None
- for idx in range(0, len(msg_list), 2):
- if user_role is None and assistant_role is None:
- user_role = msg_list[idx][dataset_attr.role]
- assistant_role = msg_list[idx + 1][dataset_attr.role]
- else:
- if (
- msg_list[idx][dataset_attr.role] != user_role
- or msg_list[idx+1][dataset_attr.role] != assistant_role
- ):
- raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
- msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))
- if len(msg_pairs) != 0:
- outputs["prompt"].append(msg_pairs[-1][0])
- outputs["query"].append("")
- outputs["response"].append(msg_pairs[-1][1])
- outputs["history"].append(msg_pairs[:-1] if len(msg_pairs) > 1 else None)
- outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
-
- return outputs
-
- if dataset_attr.formatting == "sharegpt": # convert format
- column_names = list(next(iter(dataset)).keys())
- kwargs = {}
- if not data_args.streaming:
- kwargs = dict(
- num_proc=data_args.preprocessing_num_workers,
- load_from_cache_file=(not data_args.overwrite_cache),
- desc="Converting format of dataset"
- )
-
- dataset = dataset.map(
- convert_format,
- batched=True,
- remove_columns=column_names,
- **kwargs
- )
- else:
- for column_name in ["prompt", "query", "response", "history", "system"]: # align dataset
- if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
- dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
-
- all_datasets.append(dataset)
-
- if len(data_args.dataset_list) == 1:
+def merge_dataset(
+ all_datasets: List[Union["Dataset", "IterableDataset"]],
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments"
+) -> Union["Dataset", "IterableDataset"]:
+ if len(all_datasets) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
@@ -166,8 +120,72 @@ def get_dataset(
return interleave_datasets(
datasets=all_datasets,
probabilities=data_args.interleave_probs,
- seed=data_args.seed,
+ seed=training_args.seed,
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
)
else:
raise ValueError("Unknown mixing strategy.")
+
+
+def get_dataset(
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ tokenizer: "PreTrainedTokenizer",
+ training_args: "Seq2SeqTrainingArguments",
+ stage: Literal["pt", "sft", "rm", "ppo"],
+ # split: Optional[str] = "train", # TODO: add split
+) -> Union["Dataset", "IterableDataset"]:
+ template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
+ if data_args.train_on_prompt and template.efficient_eos:
+ raise ValueError("Current template does not support `train_on_prompt`.")
+
+ # Load from cache
+ if data_args.cache_path is not None:
+ if os.path.exists(data_args.cache_path):
+ logger.warning("Loading dataset from disk will ignore other data arguments.")
+ dataset = load_from_disk(data_args.cache_path)
+ if data_args.streaming:
+ dataset = dataset.to_iterable_dataset()
+ return dataset
+
+ if data_args.streaming:
+ raise ValueError("Turn off dataset streaming to save cache files.")
+
+ with training_args.main_process_first(desc="load dataset"):
+ all_datasets = []
+ for dataset_attr in get_dataset_list(data_args): # TODO: add split
+ all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
+ dataset = merge_dataset(all_datasets, data_args, training_args)
+
+ with training_args.main_process_first(desc="pre-process dataset"):
+ preprocess_func, print_function = get_preprocess_and_print_func(
+ tokenizer, template, data_args, training_args, stage
+ )
+ column_names = list(next(iter(dataset)).keys())
+ kwargs = {}
+ if not data_args.streaming:
+ kwargs = dict(
+ num_proc=data_args.preprocessing_num_workers,
+ load_from_cache_file=(not data_args.overwrite_cache),
+ desc="Running tokenizer on dataset"
+ )
+
+ dataset = dataset.map(
+ preprocess_func,
+ batched=True,
+ remove_columns=column_names,
+ **kwargs
+ )
+
+ if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
+ if training_args.should_save:
+ dataset.save_to_disk(data_args.cache_path)
+ logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
+
+ if training_args.should_log:
+ try:
+ print_function(next(iter(dataset)))
+ except StopIteration:
+ raise RuntimeError("Empty dataset!")
+
+ return dataset
diff --git a/src/llmtuner/data/parser.py b/src/llmtuner/data/parser.py
new file mode 100644
index 00000000..9da5b732
--- /dev/null
+++ b/src/llmtuner/data/parser.py
@@ -0,0 +1,101 @@
+import os
+import json
+from typing import TYPE_CHECKING, List, Literal, Optional
+from dataclasses import dataclass
+
+from ..extras.constants import DATA_CONFIG
+from ..extras.misc import use_modelscope
+
+if TYPE_CHECKING:
+ from ..hparams import DataArguments
+
+
+@dataclass
+class DatasetAttr:
+
+ load_from: Literal["hf_hub", "ms_hub", "script", "file"]
+ dataset_name: Optional[str] = None
+ dataset_sha1: Optional[str] = None
+ subset: Optional[str] = None
+ folder: Optional[str] = None
+ ranking: Optional[bool] = False
+ formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
+
+ system: Optional[str] = None
+
+ prompt: Optional[str] = "instruction"
+ query: Optional[str] = "input"
+ response: Optional[str] = "output"
+ history: Optional[str] = None
+
+ messages: Optional[str] = "conversations"
+ tools: Optional[str] = None
+
+ role_tag: Optional[str] = "from"
+ content_tag: Optional[str] = "value"
+ user_tag: Optional[str] = "human"
+ assistant_tag: Optional[str] = "gpt"
+ observation_tag: Optional[str] = "observation"
+ function_tag: Optional[str] = "function_call"
+
+ def __repr__(self) -> str:
+ return self.dataset_name
+
+
+def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
+ dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
+ try:
+ with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
+ dataset_info = json.load(f)
+ except Exception as err:
+ if data_args.dataset is not None:
+ raise ValueError("Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)))
+ dataset_info = None
+
+ if data_args.interleave_probs is not None:
+ data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
+
+ dataset_list: List[DatasetAttr] = []
+ for name in dataset_names:
+ if name not in dataset_info:
+ raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
+
+ has_hf_url = "hf_hub_url" in dataset_info[name]
+ has_ms_url = "ms_hub_url" in dataset_info[name]
+
+ if has_hf_url or has_ms_url:
+ if (use_modelscope() and has_ms_url) or (not has_hf_url):
+ dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
+ else:
+ dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
+ elif "script_url" in dataset_info[name]:
+ dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
+ else:
+ dataset_attr = DatasetAttr(
+ "file",
+ dataset_name=dataset_info[name]["file_name"],
+ dataset_sha1=dataset_info[name].get("file_sha1", None)
+ )
+
+ dataset_attr.subset = dataset_info[name].get("subset", None)
+ dataset_attr.folder = dataset_info[name].get("folder", None)
+ dataset_attr.ranking = dataset_info[name].get("ranking", False)
+ dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
+
+ if "columns" in dataset_info[name]:
+ if dataset_attr.formatting == "alpaca":
+ column_names = ["prompt", "query", "response", "history"]
+ else:
+ column_names = ["messages", "tools"]
+
+ column_names += ["system"]
+ for column_name in column_names:
+ setattr(dataset_attr, column_name, dataset_info[name]["columns"].get(column_name, None))
+
+ if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
+ for tag in ["role_tag", "content_tag", "user_tag", "assistant_tag", "observation_tag", "function_tag"]:
+ setattr(dataset_attr, tag, dataset_info[name]["tags"].get(tag, None))
+
+ dataset_list.append(dataset_attr)
+
+ return dataset_list
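As a concrete example of what `get_dataset_list` produces, the sketch below replays its column handling for the `glaive_toolcall` entry added to `dataset_info.json`: a file-based `DatasetAttr` with remapped sharegpt columns and the default role tags.

```python
from llmtuner.data.parser import DatasetAttr

# Manually replaying get_dataset_list's handling of the glaive_toolcall entry.
info = {
    "file_name": "glaive_toolcall_10k.json",
    "formatting": "sharegpt",
    "columns": {"messages": "conversations", "tools": "tools"},
}

attr = DatasetAttr("file", dataset_name=info["file_name"])
attr.formatting = info.get("formatting", "alpaca")
for column_name in ["messages", "tools", "system"]:
    setattr(attr, column_name, info["columns"].get(column_name, None))

print(attr, attr.messages, attr.tools, attr.role_tag)
# glaive_toolcall_10k.json conversations tools from
```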
diff --git a/src/llmtuner/data/preprocess.py b/src/llmtuner/data/preprocess.py
index 6f98c8f5..81caeda7 100644
--- a/src/llmtuner/data/preprocess.py
+++ b/src/llmtuner/data/preprocess.py
@@ -1,272 +1,241 @@
-import os
-import tiktoken
+from functools import partial
from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
-from llmtuner.data.template import get_template_and_fix_tokenizer
-from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.logging import get_logger
+from ..extras.constants import IGNORE_INDEX
+from ..extras.logging import get_logger
if TYPE_CHECKING:
- from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
- from llmtuner.hparams import DataArguments
+
+ from ..hparams import DataArguments
+ from .template import Template
logger = get_logger(__name__)
-def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
- for i in range(len(examples["prompt"])):
- query, response = examples["prompt"][i], examples["response"][i]
- query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
- history = examples["history"][i] if "history" in examples else None
- system = examples["system"][i] if "system" in examples else None
- yield query, response, history, system
-
-
-def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
- max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
- max_target_len = max(max_target_len, data_args.reserved_label_len)
- max_source_len = data_args.cutoff_len - max_target_len
- return max_source_len, max_target_len
-
-
-def preprocess_dataset(
- dataset: Union["Dataset", "IterableDataset"],
+def preprocess_pretrain_dataset(
+ examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
+ data_args: "DataArguments"
+) -> Dict[str, List[List[int]]]:
+ # build grouped texts with format `X1 X2 X3 ...`
+ text_examples = [examples["prompt"][i][0]["content"] for i in range(len(examples["prompt"]))]
+ tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
+ for i in range(len(tokenized_examples["input_ids"])):
+ tokenized_examples["input_ids"][i] += [tokenizer.eos_token_id]
+ tokenized_examples["attention_mask"][i] += [1]
+
+ concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
+ total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
+ block_size = data_args.cutoff_len
+ # we drop the small remainder, and if the total_length < block_size, we exclude this batch
+ total_length = (total_length // block_size) * block_size
+ # split by chunks of cutoff_len
+ result = {
+ k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
+ for k, t in concatenated_examples.items()
+ }
+ return result
+
+
+def preprocess_supervised_dataset(
+ examples: Dict[str, List[Any]],
+ tokenizer: "PreTrainedTokenizer",
+ template: "Template",
data_args: "DataArguments",
- training_args: "Seq2SeqTrainingArguments",
- stage: Literal["pt", "sft", "rm", "ppo"]
-) -> Union["Dataset", "IterableDataset"]:
- template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
+) -> Dict[str, List[List[int]]]:
+ # build inputs with format ` X Y ` and labels with format ` ... Y `
+ # for multiturn examples, we only mask the prompt part in each prompt-response pair.
+ model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
- if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
- return dataset # already preprocessed
+ for i in range(len(examples["prompt"])):
+ if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
+ continue
- if data_args.train_on_prompt and template.efficient_eos:
- raise ValueError("Current template does not support `train_on_prompt`.")
-
- def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
- # build grouped texts with format `X1 X2 X3 ...`
- if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
- kwargs = dict(allowed_special="all")
- else:
- kwargs = dict(add_special_tokens=True)
-
- if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer
- add_eos_token_flag = getattr(tokenizer, "add_eos_token")
- setattr(tokenizer, "add_eos_token", True)
-
- tokenized_examples = tokenizer(examples["prompt"], **kwargs)
- concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
- total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
- block_size = data_args.cutoff_len
- # we drop the small remainder, and if the total_length < block_size, we exclude this batch
- total_length = (total_length // block_size) * block_size
- # split by chunks of cutoff_len
- result = {
- k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
- for k, t in concatenated_examples.items()
- }
- # make sure the saved tokenizer is the same as the original one
- if hasattr(tokenizer, "add_eos_token"):
- setattr(tokenizer, "add_eos_token", add_eos_token_flag)
- return result
-
- def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
- # build inputs with format ` X Y ` and labels with format ` ... Y `
- # for multiturn examples, we only mask the prompt part in each prompt-response pair.
- model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-
- for query, response, history, system in construct_example(examples):
- if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
- continue
-
- input_ids, labels = [], []
- for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
- tokenizer, query, response, history, system
- )):
- source_len, target_len = len(source_ids), len(target_ids)
- max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
- if source_len > max_source_len:
- source_ids = source_ids[:max_source_len]
- if target_len > max_target_len:
- target_ids = target_ids[:max_target_len]
-
- if data_args.train_on_prompt:
- source_mask = source_ids
- elif turn_idx != 0 and template.efficient_eos:
- source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
- else:
- source_mask = [IGNORE_INDEX] * len(source_ids)
-
- input_ids += source_ids + target_ids
- labels += source_mask + target_ids
-
- if template.efficient_eos:
- input_ids += [tokenizer.eos_token_id]
- labels += [tokenizer.eos_token_id]
-
- if len(input_ids) > data_args.cutoff_len:
- input_ids = input_ids[:data_args.cutoff_len]
- labels = labels[:data_args.cutoff_len]
-
- model_inputs["input_ids"].append(input_ids)
- model_inputs["attention_mask"].append([1] * len(input_ids))
- model_inputs["labels"].append(labels)
-
- return model_inputs
-
- def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
- # build inputs with format ` X1 Y1 X2 Y2 `
- # and labels with format ` ... Y1 ... Y2 `
- model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+ messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = [], []
- for query, response, history, system in construct_example(examples):
- if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
- continue
+ for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
+ tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+ )):
+ if data_args.train_on_prompt:
+ source_mask = source_ids
+ elif turn_idx != 0 and template.efficient_eos:
+ source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
+ else:
+ source_mask = [IGNORE_INDEX] * len(source_ids)
- for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
- tokenizer, query, response, history, system
- )):
- if data_args.train_on_prompt:
- source_mask = source_ids
- elif turn_idx != 0 and template.efficient_eos:
- source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
- else:
- source_mask = [IGNORE_INDEX] * len(source_ids)
- input_ids += source_ids + target_ids
- labels += source_mask + target_ids
+ input_ids += source_ids + target_ids
+ labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
- total_length = len(input_ids)
- block_size = data_args.cutoff_len
- # we drop the small remainder, and if the total_length < block_size, we exclude this batch
- total_length = (total_length // block_size) * block_size
- # split by chunks of cutoff_len
- for i in range(0, total_length, block_size):
- model_inputs["input_ids"].append(input_ids[i: i + block_size])
- model_inputs["attention_mask"].append([1] * block_size)
- model_inputs["labels"].append(labels[i: i + block_size])
+ model_inputs["input_ids"].append(input_ids)
+ model_inputs["attention_mask"].append([1] * len(input_ids))
+ model_inputs["labels"].append(labels)
- return model_inputs
+ return model_inputs
- def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
- # build inputs with format ` X` and labels with format `Y `
- model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
- for query, response, history, system in construct_example(examples):
- if not (isinstance(query, str) and query != ""):
- continue
+def preprocess_packed_supervised_dataset(
+ examples: Dict[str, List[Any]],
+ tokenizer: "PreTrainedTokenizer",
+ template: "Template",
+ data_args: "DataArguments",
+) -> Dict[str, List[List[int]]]:
+ # build inputs with format ` X1 Y1 X2 Y2 `
+ # and labels with format ` ... Y1 ... Y2 `
+ model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+ input_ids, labels = [], []
+ for i in range(len(examples["prompt"])):
+ if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
+ continue
- input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
+ messages = examples["prompt"][i] + examples["response"][i]
+ for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
+ tokenizer, messages, examples["system"][i], examples["tools"][i]
+ )):
+ if data_args.train_on_prompt:
+ source_mask = source_ids
+ elif turn_idx != 0 and template.efficient_eos:
+ source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
+ else:
+ source_mask = [IGNORE_INDEX] * len(source_ids)
- if template.efficient_eos:
- labels += [tokenizer.eos_token_id]
+ input_ids += source_ids + target_ids
+ labels += source_mask + target_ids
- if len(input_ids) > data_args.cutoff_len:
- input_ids = input_ids[:data_args.cutoff_len]
- if len(labels) > data_args.cutoff_len:
- labels = labels[:data_args.cutoff_len]
+ if template.efficient_eos:
+ input_ids += [tokenizer.eos_token_id]
+ labels += [tokenizer.eos_token_id]
- model_inputs["input_ids"].append(input_ids)
- model_inputs["attention_mask"].append([1] * len(input_ids))
- model_inputs["labels"].append(labels)
+ total_length = len(input_ids)
+ block_size = data_args.cutoff_len
+ # we drop the small remainder, and if the total_length < block_size, we exclude this batch
+ total_length = (total_length // block_size) * block_size
+ # split by chunks of cutoff_len
+ for i in range(0, total_length, block_size):
+ model_inputs["input_ids"].append(input_ids[i: i + block_size])
+ model_inputs["attention_mask"].append([1] * block_size)
+ model_inputs["labels"].append(labels[i: i + block_size])
- return model_inputs
+ return model_inputs
- def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
- # build input pairs with format ` X`, `Y1 ` and `Y2 `
- model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
- for query, response, history, system in construct_example(examples):
- if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
- continue
- prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
- _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
+def preprocess_unsupervised_dataset(
+ examples: Dict[str, List[Any]],
+ tokenizer: "PreTrainedTokenizer",
+ template: "Template",
+ data_args: "DataArguments",
+) -> Dict[str, List[List[int]]]:
+ # build inputs with format ` X` and labels with format `Y `
+ model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
- if template.efficient_eos:
- chosen_ids += [tokenizer.eos_token_id]
- rejected_ids += [tokenizer.eos_token_id]
+ for i in range(len(examples["prompt"])):
+ if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
+ continue
- source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids))
- max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
- if source_len > max_source_len:
- prompt_ids = prompt_ids[:max_source_len]
- if target_len > max_target_len:
- chosen_ids = chosen_ids[:max_target_len]
- rejected_ids = rejected_ids[:max_target_len]
-
- model_inputs["prompt_ids"].append(prompt_ids)
- model_inputs["chosen_ids"].append(chosen_ids)
- model_inputs["rejected_ids"].append(rejected_ids)
-
- return model_inputs
-
- def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None:
- print("input_ids:\n{}".format(example["input_ids"]))
- print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
- print("label_ids:\n{}".format(example["labels"]))
- print("labels:\n{}".format(
- tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
- ))
-
- def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
- print("prompt_ids:\n{}".format(example["prompt_ids"]))
- print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
- print("chosen_ids:\n{}".format(example["chosen_ids"]))
- print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
- print("rejected_ids:\n{}".format(example["rejected_ids"]))
- print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
-
- def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None:
- print("input_ids:\n{}".format(example["input_ids"]))
- print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
-
- if stage == "pt":
- preprocess_func = preprocess_pretrain_dataset
- print_function = print_unsupervised_dataset_example
- elif stage == "sft" and not training_args.predict_with_generate:
- preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
- print_function = print_supervised_dataset_example
- elif stage == "rm":
- preprocess_func = preprocess_pairwise_dataset
- print_function = print_pairwise_dataset_example
- else:
- preprocess_func = preprocess_unsupervised_dataset
- print_function = print_unsupervised_dataset_example
-
- with training_args.main_process_first(desc="dataset map pre-processing"):
- column_names = list(next(iter(dataset)).keys())
- kwargs = {}
- if not data_args.streaming:
- kwargs = dict(
- num_proc=data_args.preprocessing_num_workers,
- load_from_cache_file=(not data_args.overwrite_cache),
- desc="Running tokenizer on dataset"
- )
-
- dataset = dataset.map(
- preprocess_func,
- batched=True,
- remove_columns=column_names,
- **kwargs
+ messages = examples["prompt"][i] + examples["response"][i]
+ input_ids, labels = template.encode_oneturn(
+ tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
)
- if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
- if training_args.should_save:
- dataset.save_to_disk(data_args.cache_path)
- logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
+ if template.efficient_eos:
+ labels += [tokenizer.eos_token_id]
- if training_args.should_log:
- try:
- print_function(next(iter(dataset)))
- except StopIteration:
- raise RuntimeError("Empty dataset!")
+ model_inputs["input_ids"].append(input_ids)
+ model_inputs["attention_mask"].append([1] * len(input_ids))
+ model_inputs["labels"].append(labels)
- return dataset
+ return model_inputs
+
+
+def preprocess_pairwise_dataset(
+ examples: Dict[str, List[Any]],
+ tokenizer: "PreTrainedTokenizer",
+ template: "Template",
+ data_args: "DataArguments",
+) -> Dict[str, List[List[int]]]:
+ # build input pairs with format ` X`, `Y1 ` and `Y2 `
+ model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
+ for i in range(len(examples["prompt"])):
+ if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) < 2:
+ continue
+
+ chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
+ rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
+
+ prompt_ids, chosen_ids = template.encode_oneturn(
+ tokenizer, chosen_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+ )
+ _, rejected_ids = template.encode_oneturn(
+ tokenizer, rejected_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+ )
+
+ if template.efficient_eos:
+ chosen_ids += [tokenizer.eos_token_id]
+ rejected_ids += [tokenizer.eos_token_id]
+
+ model_inputs["prompt_ids"].append(prompt_ids)
+ model_inputs["chosen_ids"].append(chosen_ids)
+ model_inputs["rejected_ids"].append(rejected_ids)
+
+ return model_inputs
+
+
+def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+ print("input_ids:\n{}".format(example["input_ids"]))
+ print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
+ print("label_ids:\n{}".format(example["labels"]))
+ print("labels:\n{}".format(
+ tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
+ ))
+
+
+def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+ print("prompt_ids:\n{}".format(example["prompt_ids"]))
+ print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
+ print("chosen_ids:\n{}".format(example["chosen_ids"]))
+ print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
+ print("rejected_ids:\n{}".format(example["rejected_ids"]))
+ print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
+
+
+def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+ print("input_ids:\n{}".format(example["input_ids"]))
+ print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
+
+
+def get_preprocess_and_print_func(
+ tokenizer: "PreTrainedTokenizer",
+ template: "Template",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ stage: Literal["pt", "sft", "rm", "ppo"],
+) -> Tuple[Callable, Callable]:
+ if stage == "pt":
+ preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
+ print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
+ elif stage == "sft" and not training_args.predict_with_generate:
+ if data_args.sft_packing:
+ preprocess_func = partial(
+ preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
+ )
+ else:
+ preprocess_func = partial(
+ preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
+ )
+
+ print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
+ elif stage == "rm":
+ preprocess_func = partial(preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args)
+ print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
+ else:
+ preprocess_func = partial(preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args)
+ print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
+
+ return preprocess_func, print_function
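
For reference, the rm-stage features keep the shared prompt and the two candidate responses as separate id sequences, which the pairwise collators pad later. A minimal sketch of that layout, using a toy whitespace encoder in place of the real tokenizer and chat template (the helper names below are illustrative only):

```python
from typing import Dict, List


def toy_encode(text: str, vocab: Dict[str, int]) -> List[int]:
    # Assign an id to each whitespace-separated token on first sight (illustration only).
    return [vocab.setdefault(tok, len(vocab)) for tok in text.split()]


def build_pairwise_example(prompt: str, chosen: str, rejected: str) -> Dict[str, List[int]]:
    # Mirrors the keys emitted by preprocess_pairwise_dataset: one shared prompt
    # sequence plus separately encoded chosen/rejected responses.
    vocab: Dict[str, int] = {}
    return {
        "prompt_ids": toy_encode(prompt, vocab),
        "chosen_ids": toy_encode(chosen, vocab),
        "rejected_ids": toy_encode(rejected, vocab),
    }


print(build_pairwise_example(
    "Human: What is 2 + 2?\nAssistant:",
    "2 + 2 equals 4.",
    "I am not sure.",
))
```
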
diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py
index 9fa4be57..f4537d86 100644
--- a/src/llmtuner/data/template.py
+++ b/src/llmtuner/data/template.py
@@ -1,8 +1,10 @@
-import tiktoken
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
+
+from ..extras.logging import get_logger
+from .utils import Role
+from .formatter import StringFormatter, FunctionFormatter, ToolFormatter
-from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@@ -14,28 +16,30 @@ logger = get_logger(__name__)
@dataclass
class Template:
- prefix: List[Union[str, Dict[str, str]]]
- prompt: List[Union[str, Dict[str, str]]]
+ format_user: Callable
+ format_assistant: Callable
+ format_system: Callable
+ format_tool: Callable
+ format_observation: Callable
+ format_function: Callable
system: str
- sep: List[Union[str, Dict[str, str]]]
+ separator: List[Union[str, Dict[str, str]]]
stop_words: List[str]
- use_history: bool
efficient_eos: bool
replace_eos: bool
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
- query: str,
- resp: str,
- history: Optional[List[Tuple[str, str]]] = None,
- system: Optional[str] = None
+ messages: List[Dict[str, str]],
+ system: str,
+ tools: str,
+ cutoff_len: Optional[int] = 1_000_000
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
- system, history = self._format(query, resp, history, system)
- encoded_pairs = self._encode(tokenizer, system, history)
+ encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len)
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids = prompt_ids + query_ids + resp_ids
@@ -46,109 +50,89 @@ class Template:
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
- query: str,
- resp: str,
- history: Optional[List[Tuple[str, str]]] = None,
- system: Optional[str] = None
+ messages: List[Dict[str, str]],
+ system: str,
+ tools: str,
+ cutoff_len: Optional[int] = 1_000_000
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
- system, history = self._format(query, resp, history, system)
- encoded_pairs = self._encode(tokenizer, system, history)
+ encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len)
return encoded_pairs
- def _format(
- self,
- query: str,
- resp: str,
- history: Optional[List[Tuple[str, str]]] = None,
- system: Optional[str] = None
- ) -> Tuple[str, List[Tuple[str, str]]]:
- r"""
- Aligns inputs to the standard format.
- """
- system = system or self.system # use system if provided
- history = history if (history and self.use_history) else []
- history = history + [(query, resp)]
- return system, history
-
- def _get_special_ids(
- self,
- tokenizer: "PreTrainedTokenizer"
- ) -> Tuple[List[int], List[int]]:
- if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
- bos_ids = [tokenizer.bos_token_id]
- else: # baichuan, gpt2, qwen, yi models have no bos token
- bos_ids = []
-
- if tokenizer.eos_token_id is None:
- raise ValueError("EOS token is required.")
-
- if self.efficient_eos:
- eos_ids = []
- else:
- eos_ids = [tokenizer.eos_token_id]
-
- return bos_ids, eos_ids
-
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
+ messages: List[Dict[str, str]],
system: str,
- history: List[Tuple[str, str]]
+ tools: str,
+ cutoff_len: int
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
- Turn 0: bos + prefix + sep + query resp + eos
- Turn t: sep + bos + query resp + eos
+ Turn 0: system + query resp + eos
+ Turn t: sep + query resp + eos
"""
- bos_ids, eos_ids = self._get_special_ids(tokenizer)
- sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
- encoded_pairs = []
- for turn_idx, (query, resp) in enumerate(history):
- if turn_idx == 0:
- prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
- if len(prefix_ids) != 0: # has prefix
- prefix_ids = bos_ids + prefix_ids + sep_ids
- else:
- prefix_ids = bos_ids
- else:
- prefix_ids = sep_ids + bos_ids
+ system = system or self.system
+ encoded_messages = []
+ for i, message in enumerate(messages):
+ elements = []
+ if i == 0 and (system or tools):
+ tool_text = self.format_tool(content=tools)[0] if tools else ""
+ elements += self.format_system(content=(system + tool_text))
+ elif i > 0 and i % 2 == 0:
+ elements += self.separator
+
+ if message["role"] == Role.USER:
+ elements += self.format_user(content=message["content"], idx=str(i // 2))
+ elif message["role"] == Role.ASSISTANT:
+ elements += self.format_assistant(content=message["content"])
+ elif message["role"] == Role.OBSERVATION:
+ elements += self.format_observation(content=message["content"])
+ elif message["role"] == Role.FUNCTION:
+ elements += self.format_function(content=message["content"])
+
+ encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
+
+ # TODO: need to improve
+ encoded_pairs = []
+ total_length = 0
+ for i in range(0, len(encoded_messages), 2):
+ if total_length >= cutoff_len:
+ break
+
+ encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length]
+ total_length += len(encoded_messages[i])
+
+ encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)]
+ total_length += len(encoded_messages[i+1])
+ encoded_pairs.append((encoded_messages[i], encoded_messages[i+1]))
- query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx+1))
- resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
- encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
return encoded_pairs
- def _convert_inputs_to_ids(
+ def _convert_elements_to_ids(
self,
tokenizer: "PreTrainedTokenizer",
- context: List[Union[str, Dict[str, str]]],
- system: Optional[str] = None,
- query: Optional[str] = None,
- idx: Optional[str] = None
+ elements: List[Union[str, Dict[str, str]]]
) -> List[int]:
r"""
- Converts context to token ids.
+ Converts elements to token ids.
"""
- if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
- kwargs = dict(allowed_special="all")
- else:
- kwargs = dict(add_special_tokens=False)
-
token_ids = []
- for elem in context:
+ for elem in elements:
if isinstance(elem, str):
- elem = elem.replace("{{system}}", system, 1) if system is not None else elem
- elem = elem.replace("{{query}}", query, 1) if query is not None else elem
- elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
if len(elem) != 0:
- token_ids = token_ids + tokenizer.encode(elem, **kwargs)
+ token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False)
elif isinstance(elem, dict):
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
+ elif isinstance(elem, set):
+ if "bos_token" in elem and tokenizer.bos_token_id:
+ token_ids = token_ids + [tokenizer.bos_token_id]
+ elif "eos_token" in elem and tokenizer.eos_token_id:
+ token_ids = token_ids + [tokenizer.eos_token_id]
else:
- raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem)))
+ raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
return token_ids
@@ -159,22 +143,52 @@ class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
+ messages: List[Dict[str, str]],
system: str,
- history: List[Tuple[str, str]]
+ tools: str,
+ cutoff_len: int
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
- Turn 0: bos + prefix + query resp + eos
- Turn t: bos + query resp + eos
+ Turn 0: system + query resp + eos
+ Turn t: sep + query resp + eos
"""
- bos_ids, eos_ids = self._get_special_ids(tokenizer)
+ system = system or self.system
+ encoded_messages = []
+ for i, message in enumerate(messages):
+ elements = []
+ system_text = ""
+ if i == 0 and (system or tools):
+ tool_text = self.format_tool(content=tools)[0] if tools else ""
+ system_text = self.format_system(content=(system + tool_text))[0]
+ elif i > 0 and i % 2 == 0:
+ elements += self.separator
+
+ if message["role"] == Role.USER:
+ elements += self.format_user(content=system_text + message["content"], idx=str(i // 2))
+ elif message["role"] == Role.ASSISTANT:
+ elements += self.format_assistant(content=message["content"])
+ elif message["role"] == Role.OBSERVATION:
+ elements += self.format_observation(content=message["content"])
+ elif message["role"] == Role.FUNCTION:
+ elements += self.format_function(content=message["content"])
+
+ encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
+
+ # TODO: need to improve
encoded_pairs = []
- for turn_idx, (query, resp) in enumerate(history):
- if turn_idx == 0: # llama2 template has no sep_ids
- query = self.prefix[0].replace("{{system}}", system) + query
- query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
- resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
- encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
+ total_length = 0
+ for i in range(0, len(encoded_messages), 2):
+ if total_length >= cutoff_len:
+ break
+
+ encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length]
+ total_length += len(encoded_messages[i])
+
+ encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)]
+ total_length += len(encoded_messages[i+1])
+ encoded_pairs.append((encoded_messages[i], encoded_messages[i+1]))
+
return encoded_pairs
@@ -183,23 +197,33 @@ templates: Dict[str, Template] = {}
def register_template(
name: str,
- prefix: List[Union[str, Dict[str, str]]],
- prompt: List[Union[str, Dict[str, str]]],
- system: str,
- sep: List[Union[str, Dict[str, str]]],
+ format_user: Optional[Callable] = None,
+ format_assistant: Optional[Callable] = None,
+ format_system: Optional[Callable] = None,
+ format_tool: Optional[Callable] = None,
+ format_observation: Optional[Callable] = None,
+ format_function: Optional[Callable] = None,
+ system: Optional[str] = "",
+ separator: Optional[List[Union[str, Dict[str, str]]]] = "",
stop_words: Optional[List[str]] = [],
- use_history: Optional[bool] = True,
efficient_eos: Optional[bool] = False,
replace_eos: Optional[bool] = False
) -> None:
template_class = Llama2Template if name.startswith("llama2") else Template
templates[name] = template_class(
- prefix=prefix,
- prompt=prompt,
+ format_user=format_user or StringFormatter(container=["{{content}}"]),
+ format_assistant=format_assistant or StringFormatter(container=[
+ "{{content}}", {"eos_token"}
+ ]),
+ format_system=format_system or StringFormatter(container=["{{content}}"]),
+ format_tool=format_tool or ToolFormatter(type="default"),
+ format_observation=format_observation or format_user,
+ format_function=format_function or FunctionFormatter(container=[
+ "Action: {{name}}\nAction Input: {{arguments}}", {"eos_token"}
+ ]),
system=system,
- sep=sep,
+ separator=separator,
stop_words=stop_words,
- use_history=use_history,
efficient_eos=efficient_eos,
replace_eos=replace_eos
)
@@ -244,17 +268,14 @@ def get_template_and_fix_tokenizer(
register_template(
name="alpaca",
- prefix=[
- "{{system}}"
- ],
- prompt=[
- "### Instruction:\n{{query}}\n\n### Response:\n"
- ],
+ format_user=StringFormatter(container=[
+ "### Instruction:\n{{content}}\n\n### Response:\n"
+ ]),
system=(
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request."
),
- sep=[
+ separator=[
"\n\n"
]
)
@@ -262,17 +283,14 @@ register_template(
register_template(
name="aquila",
- prefix=[
- "{{system}}"
- ],
- prompt=[
- "Human: {{query}}###Assistant:"
- ],
+ format_user=StringFormatter(container=[
+ "Human: {{content}}###Assistant:"
+ ]),
system=(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
),
- sep=[
+ separator=[
"###"
],
stop_words=[
@@ -284,46 +302,32 @@ register_template(
register_template(
name="baichuan",
- prefix=[
- "{{system}}"
- ],
- prompt=[
- {"token": ""}, # user token
- "{{query}}",
- {"token": ""} # assistant token
- ],
- system="",
- sep=[],
+ format_user=StringFormatter(container=[
+ {"token": ""},
+ "{{content}}",
+ {"token": ""}
+ ]),
efficient_eos=True
)
register_template(
name="baichuan2",
- prefix=[
- "{{system}}"
- ],
- prompt=[
- {"token": ""}, # user token
- "{{query}}",
- {"token": ""} # assistant token
- ],
- system="",
- sep=[],
+ format_user=StringFormatter(container=[
+ {"token": ""},
+ "{{content}}",
+ {"token": ""}
+ ]),
efficient_eos=True
)
register_template(
name="belle",
- prefix=[
- "{{system}}"
- ],
- prompt=[
- "Human: {{query}}\n\nBelle: "
- ],
- system="",
- sep=[
+ format_user=StringFormatter(container=[
+ "Human: {{content}}\n\nBelle: "
+ ]),
+ separator=[
"\n\n"
]
)
@@ -331,31 +335,25 @@ register_template(
register_template(
name="bluelm",
- prefix=[
- "{{system}}"
- ],
- prompt=[
+ format_user=StringFormatter(container=[
{"token": "[|Human|]:"},
- "{{query}}",
+ "{{content}}",
{"token": "[|AI|]:"}
- ],
- system="",
- sep=[]
+ ])
)
register_template(
name="chatglm2",
- prefix=[
+ format_user=StringFormatter(container=[
+ "[Round {{idx}}]\n\n问:{{content}}\n\n答:"
+ ]),
+ format_system=StringFormatter(container=[
{"token": "[gMASK]"},
{"token": "sop"},
- "{{system}}"
- ],
- prompt=[
- "[Round {{idx}}]\n\n问:{{query}}\n\n答:"
- ],
- system="",
- sep=[
+ "{{content}}"
+ ]),
+ separator=[
"\n\n"
],
efficient_eos=True
@@ -364,53 +362,35 @@ register_template(
register_template(
name="chatglm3",
- prefix=[
- {"token": "[gMASK]"},
- {"token": "sop"},
- {"token": "<|system|>"},
- "\n",
- "{{system}}"
- ],
- prompt=[
+ format_user=StringFormatter(container=[
{"token": "<|user|>"},
"\n",
- "{{query}}",
- {"token": "<|assistant|>"},
- "\n" # add an extra newline to avoid error in ChatGLM's process_response method
- ],
- system=(
- "You are ChatGLM3, a large language model trained by Zhipu.AI. "
- "Follow the user's instructions carefully. Respond using markdown."
- ),
- sep=[],
- stop_words=[
- "<|user|>",
- "<|observation|>"
- ],
- efficient_eos=True
-)
-
-
-register_template(
- name="chatglm3_raw", # the raw template for tool tuning
- prefix=[
- {"token": "[gMASK]"},
- {"token": "sop"},
- {"token": "<|system|>"},
- "\n",
- "{{system}}"
- ],
- prompt=[
- {"token": "<|user|>"},
- "\n",
- "{{query}}",
+ "{{content}}",
{"token": "<|assistant|>"}
- ],
+ ]),
+ format_assistant=StringFormatter(container=[
+ "\n"
+ "{{content}}"
+ ]),
+ format_system=StringFormatter(container=[
+ {"token": "[gMASK]"},
+ {"token": "sop"},
+ {"token": "<|system|>"},
+ "\n",
+ "{{content}}"
+ ]),
+ format_observation=StringFormatter(container=[
+ {"token": "<|observation|>"},
+ "\n",
+ "{{content}}"
+ ]),
+ format_function=FunctionFormatter(container=[
+ "{{name}}\n{{arguments}}"
+ ]),
system=(
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
"Follow the user's instructions carefully. Respond using markdown."
),
- sep=[],
stop_words=[
"<|user|>",
"<|observation|>"
@@ -421,47 +401,34 @@ register_template(
register_template(
name="codegeex2",
- prefix=[
+ format_system=StringFormatter(container=[
{"token": "[gMASK]"},
{"token": "sop"},
- "{{system}}"
- ],
- prompt=[
- "{{query}}"
- ],
- system="",
- sep=[]
+ "{{content}}"
+ ])
)
register_template(
name="deepseek",
- prefix=[
- "{{system}}"
- ],
- prompt=[
- "User: {{query}}\n\nAssistant:"
- ],
- system="",
- sep=[]
+ format_user=StringFormatter(container=[
+ "User: {{content}}\n\nAssistant:"
+ ])
)
register_template(
name="deepseekcoder",
- prefix=[
- "{{system}}"
- ],
- prompt=[
- "### Instruction:\n{{query}}\n### Response:\n"
- ],
+ format_user=StringFormatter(container=[
+ "### Instruction:\n{{content}}\n### Response:\n"
+ ]),
system=(
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
"developed by Deepseek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer\n"
),
- sep=[
+ separator=[
"\n",
{"token": "<|EOT|>"},
"\n"
@@ -475,17 +442,14 @@ register_template(
register_template(
name="default",
- prefix=[
- "{{system}}"
- ],
- prompt=[
- "Human: {{query}}\nAssistant:"
- ],
+ format_user=StringFormatter(container=[
+ "Human: {{content}}\nAssistant: "
+ ]),
system=(
"A chat between a curious user and an artificial intelligence assistant. "
- "The assistant gives helpful, detailed, and polite answers to the user's questions."
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
),
- sep=[
+ separator=[
"\n"
]
)
@@ -493,14 +457,10 @@ register_template(
register_template(
name="falcon",
- prefix=[
- "{{system}}"
- ],
- prompt=[
- "User: {{query}}\nFalcon:"
- ],
- system="",
- sep=[
+ format_user=StringFormatter(container=[
+ "User: {{content}}\nFalcon:"
+ ]),
+ separator=[
"\n"
],
efficient_eos=True
@@ -509,16 +469,12 @@ register_template(
register_template(
name="intern",
- prefix=[
- "{{system}}"
- ],
- prompt=[
- "<|User|>:{{query}}",
+ format_user=StringFormatter(container=[
+ "<|User|>:{{content}}",
{"token": ""},
"\n<|Bot|>:"
- ],
- system="",
- sep=[
+ ]),
+ separator=[
{"token": ""},
"\n"
],
@@ -529,14 +485,44 @@ register_template(
)
+register_template(
+ name="intern2",
+ format_user=StringFormatter(container=[
+ {"token": "[UNUSED_TOKEN_146]"},
+ "user\n{{content}}",
+ {"token": "[UNUSED_TOKEN_145]"},
+ "\n",
+ {"token": "[UNUSED_TOKEN_146]"},
+ "assistant\n"
+ ]),
+ format_system=StringFormatter(container=[
+ {"token": "[UNUSED_TOKEN_146]"},
+ "system\n{{content}}",
+ {"token": "[UNUSED_TOKEN_145]"},
+ "\n"
+ ]),
+ system=(
+ "You are an AI assistant whose name is InternLM (书生·浦语).\n"
+ "- InternLM (书生·浦语) is a conversational language model that is developed "
+ "by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
+ "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
+ "by the user such as English and 中文."
+ ),
+ separator=[
+ {"token": "[UNUSED_TOKEN_145]"},
+ "\n"
+ ],
+ stop_words=[
+ "[UNUSED_TOKEN_145]"
+ ],
+ efficient_eos=True
+)
+
+
register_template(
name="llama2",
- prefix=[
- "<>\n{{system}}\n<>\n\n"
- ],
- prompt=[
- "[INST] {{query}} [/INST]"
- ],
+ format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]),
+    format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
system=(
"You are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
@@ -546,49 +532,32 @@ register_template(
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
- ),
- sep=[]
+ )
)
register_template(
name="llama2_zh",
- prefix=[
- "<>\n{{system}}\n<>\n\n"
- ],
- prompt=[
- "[INST] {{query}} [/INST]"
- ],
- system="You are a helpful assistant. 你是一个乐于助人的助手。",
- sep=[]
+ format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]),
+    format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
+ system="You are a helpful assistant. 你是一个乐于助人的助手。"
)
register_template(
name="mistral",
- prefix=[
- "{{system}}"
- ],
- prompt=[
- "[INST] {{query}} [/INST]"
- ],
- system="",
- sep=[]
+ format_user=StringFormatter(container=["[INST] {{content}} [/INST]"])
)
register_template(
name="openchat",
- prefix=[
- "{{system}}"
- ],
- prompt=[
- "GPT4 Correct User: {{query}}",
+ format_user=StringFormatter(container=[
+ "GPT4 Correct User: {{content}}",
{"token": "<|end_of_turn|>"},
"GPT4 Correct Assistant:"
- ],
- system="",
- sep=[
+ ]),
+ separator=[
{"token": "<|end_of_turn|>"}
],
stop_words=[
@@ -600,14 +569,14 @@ register_template(
register_template(
name="qwen",
- prefix=[
- "<|im_start|>system\n{{system}}<|im_end|>"
- ],
- prompt=[
- "<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n"
- ],
+ format_user=StringFormatter(container=[
+ "<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
+ ]),
+ format_system=StringFormatter(container=[
+ "<|im_start|>system\n{{content}}<|im_end|>\n"
+ ]),
system="You are a helpful assistant.",
- sep=[
+ separator=[
"\n"
],
stop_words=[
@@ -619,32 +588,28 @@ register_template(
register_template(
name="solar",
- prefix=[
- "{{system}}"
- ],
- prompt=[
- "### User:\n{{query}}\n\n### Assistant:\n"
- ],
- system="",
- sep=[]
+ format_user=StringFormatter(container=[
+ "### User:\n{{content}}\n\n### Assistant:\n"
+ ])
)
register_template(
name="starchat",
- prefix=[
- {"token": "<|system|>"},
- "\n{{system}}",
- ],
- prompt=[
+ format_user=StringFormatter(container=[
{"token": "<|user|>"},
- "\n{{query}}",
+ "\n{{content}}",
{"token": "<|end|>"},
"\n",
{"token": "<|assistant|>"}
- ],
- system="",
- sep=[
+ ]),
+ format_system=StringFormatter(container=[
+ {"token": "<|system|>"},
+ "\n{{content}}",
+ {"token": "<|end|>"},
+ "\n"
+ ]),
+ separator=[
{"token": "<|end|>"},
"\n"
],
@@ -656,75 +621,55 @@ register_template(
register_template(
- name="vanilla",
- prefix=[],
- prompt=[
- "{{query}}"
- ],
- system="",
- sep=[],
- use_history=False
+ name="vanilla"
)
register_template(
name="vicuna",
- prefix=[
- "{{system}}"
- ],
- prompt=[
- "USER: {{query}} ASSISTANT:"
- ],
+ format_user=StringFormatter(container=[
+ "USER: {{content}} ASSISTANT:"
+ ]),
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
- ),
- sep=[]
+ )
)
register_template(
name="xuanyuan",
- prefix=[
- "{{system}}"
- ],
- prompt=[
- "Human: {{query}} Assistant:"
- ],
+ format_user=StringFormatter(container=[
+ "Human: {{content}} Assistant:"
+ ]),
system=(
"以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
"不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
- ),
- sep=[]
+ )
)
register_template(
name="xverse",
- prefix=[
- "{{system}}"
- ],
- prompt=[
- "Human: {{query}}\n\nAssistant: "
- ],
- system="",
- sep=[]
+ format_user=StringFormatter(container=[
+ "Human: {{content}}\n\nAssistant: "
+ ])
)
register_template(
name="yayi",
- prefix=[
- {"token": "<|System|>"},
- ":\n{{system}}"
- ],
- prompt=[
+ format_user=StringFormatter(container=[
{"token": "<|Human|>"},
- ":\n{{query}}\n\n",
+ ":\n{{content}}\n\n",
{"token": "<|YaYi|>"},
":"
- ],
+ ]),
+ format_system=StringFormatter(container=[
+ {"token": "<|System|>"},
+ ":\n{{content}}\n\n"
+ ]),
system=(
"You are a helpful, respectful and honest assistant named YaYi "
"developed by Beijing Wenge Technology Co.,Ltd. "
@@ -736,7 +681,7 @@ register_template(
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
- sep=[
+ separator=[
"\n\n"
],
stop_words=[
@@ -747,14 +692,10 @@ register_template(
register_template(
name="yi",
- prefix=[
- "{{system}}"
- ],
- prompt=[
- "<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n"
- ],
- system="",
- sep=[
+ format_user=StringFormatter(container=[
+ "<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
+ ]),
+ separator=[
"\n"
],
stop_words=[
@@ -766,15 +707,11 @@ register_template(
register_template(
name="yuan",
- prefix=[
- "{{system}}"
- ],
- prompt=[
- "{{query}}",
+ format_user=StringFormatter(container=[
+ "{{content}}",
{"token": ""}
- ],
- system="",
- sep=[
+ ]),
+ separator=[
"\n"
],
stop_words=[
@@ -786,30 +723,25 @@ register_template(
register_template(
name="zephyr",
- prefix=[
- "<|system|>\n{{system}}",
- ],
- prompt=[
- "<|user|>\n{{query}}<|assistant|>"
- ],
- system="You are a friendly chatbot who always responds in the style of a pirate",
- sep=[]
+ format_user=StringFormatter(container=[
+ "<|user|>\n{{content}}<|assistant|>"
+ ]),
+ format_system=StringFormatter(container=[
+ "<|system|>\n{{content}}",
+ ]),
+ system="You are a friendly chatbot who always responds in the style of a pirate"
)
register_template(
name="ziya",
- prefix=[
- "{{system}}"
- ],
- prompt=[
+ format_user=StringFormatter(container=[
{"token": ""},
- ":{{query}}\n",
+ ":{{content}}\n",
{"token": ""},
":"
- ],
- system="",
- sep=[
+ ]),
+ separator=[
"\n"
]
)
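
The templates above only describe what each turn looks like; the formatters imported from `formatter.py` (not part of this patch) fill in the `{{content}}`/`{{idx}}` slots and leave `{"token": ...}` and `{"eos_token"}` elements for `_convert_elements_to_ids` to resolve. A rough, assumed stand-in for that behavior (`ToyStringFormatter` is hypothetical, not the project's class):

```python
from typing import Dict, List, Set, Union

Element = Union[str, Dict[str, str], Set[str]]


class ToyStringFormatter:
    """Substitutes {{...}} slots in plain strings; passes dict/set elements through."""

    def __init__(self, container: List[Element]) -> None:
        self.container = container

    def __call__(self, **kwargs: str) -> List[Element]:
        elements: List[Element] = []
        for elem in self.container:
            if isinstance(elem, str):
                # Replace {{content}}, {{idx}}, ... with the supplied values.
                for name, value in kwargs.items():
                    elem = elem.replace("{{" + name + "}}", value)
                elements.append(elem)
            else:
                # {"token": ...} and {"eos_token"} elements are turned into ids
                # later by _convert_elements_to_ids.
                elements.append(elem)
        return elements


# e.g. the user formatter registered for the "alpaca" template above:
fmt = ToyStringFormatter(container=["### Instruction:\n{{content}}\n\n### Response:\n"])
print(fmt(content="Summarize the passage."))
```
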
diff --git a/src/llmtuner/data/utils.py b/src/llmtuner/data/utils.py
index 9dfe4dc3..106e87a7 100644
--- a/src/llmtuner/data/utils.py
+++ b/src/llmtuner/data/utils.py
@@ -1,7 +1,8 @@
import hashlib
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from enum import Enum, unique
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
-from llmtuner.extras.logging import get_logger
+from ..extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
@@ -12,6 +13,14 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
+@unique
+class Role(str, Enum):
+ USER = "user"
+ ASSISTANT = "assistant"
+ OBSERVATION = "observation"
+ FUNCTION = "function"
+
+
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
if file_sha1 is None:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
@@ -27,6 +36,13 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
+def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
+ max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
+ max_target_len = max(max_target_len, data_args.reserved_label_len)
+ max_source_len = data_args.cutoff_len - max_target_len
+ return max_source_len, max_target_len
+
+
def split_dataset(
dataset: Union["Dataset", "IterableDataset"],
data_args: "DataArguments",
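
`infer_max_len` splits the `cutoff_len` budget between source and target in proportion to their raw lengths, while never giving the target fewer than `reserved_label_len` tokens. A standalone version with the arguments flattened out of `DataArguments` for illustration:

```python
from typing import Tuple


def infer_max_len(source_len: int, target_len: int, cutoff_len: int, reserved_label_len: int) -> Tuple[int, int]:
    # Same arithmetic as the helper above, minus the DataArguments wrapper.
    max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
    max_target_len = max(max_target_len, reserved_label_len)
    max_source_len = cutoff_len - max_target_len
    return max_source_len, max_target_len


# A 900-token prompt with a 100-token answer under cutoff_len=512 keeps
# 461 prompt tokens and 51 answer tokens (reserved_label_len=1).
print(infer_max_len(source_len=900, target_len=100, cutoff_len=512, reserved_label_len=1))
```
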
diff --git a/src/llmtuner/eval/__init__.py b/src/llmtuner/eval/__init__.py
index a7c9a127..95ce0377 100644
--- a/src/llmtuner/eval/__init__.py
+++ b/src/llmtuner/eval/__init__.py
@@ -1 +1,4 @@
-from llmtuner.eval.evaluator import Evaluator
+from .evaluator import Evaluator
+
+
+__all__ = ["Evaluator"]
diff --git a/src/llmtuner/eval/evaluator.py b/src/llmtuner/eval/evaluator.py
index 0bf4c3f4..1cb55b38 100644
--- a/src/llmtuner/eval/evaluator.py
+++ b/src/llmtuner/eval/evaluator.py
@@ -3,7 +3,6 @@
import os
import json
import torch
-import tiktoken
import numpy as np
from tqdm import tqdm, trange
from typing import Any, Dict, List, Optional
@@ -11,10 +10,11 @@ from typing import Any, Dict, List, Optional
from datasets import load_dataset
from transformers.utils import cached_file
-from llmtuner.data.template import get_template_and_fix_tokenizer
-from llmtuner.eval.template import get_eval_template
-from llmtuner.extras.constants import CHOICES, SUBJECTS
-from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer
+from ..data import get_template_and_fix_tokenizer
+from .template import get_eval_template
+from ..extras.constants import CHOICES, SUBJECTS
+from ..hparams import get_eval_args
+from ..model import dispatch_model, load_model_and_tokenizer
class Evaluator:
@@ -26,15 +26,9 @@ class Evaluator:
self.model = dispatch_model(self.model)
self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
self.eval_template = get_eval_template(self.eval_args.lang)
- self.choice_inputs = self._encode_choices()
-
- def _encode_choices(self) -> List[int]:
- if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
- kwargs = dict(allowed_special="all")
- else:
- kwargs = dict(add_special_tokens=False)
-
- return [self.tokenizer.encode(self.eval_template.prefix + ch, **kwargs)[-1] for ch in CHOICES]
+ self.choice_inputs = [self.tokenizer.encode(
+ self.eval_template.prefix + ch, add_special_tokens=False
+ )[-1] for ch in CHOICES]
@torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
@@ -71,17 +65,17 @@ class Evaluator:
inputs, outputs, labels = [], [], []
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
- query, resp, history = self.eval_template.format_example(
+ messages = self.eval_template.format_example(
target_data=dataset[self.data_args.split][i],
support_set=support_set,
- subject_name=categorys[subject]["name"],
- use_history=self.template.use_history
+ subject_name=categorys[subject]["name"]
)
+
input_ids, _ = self.template.encode_oneturn(
- tokenizer=self.tokenizer, query=query, resp=resp, history=history
+ tokenizer=self.tokenizer, messages=messages
)
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
- labels.append(resp)
+ labels.append(messages[-1]["content"])
for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
batch_input = self.tokenizer.pad(
diff --git a/src/llmtuner/eval/template.py b/src/llmtuner/eval/template.py
index 2251ad57..5514e5d5 100644
--- a/src/llmtuner/eval/template.py
+++ b/src/llmtuner/eval/template.py
@@ -1,7 +1,8 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple
-from llmtuner.extras.constants import CHOICES
+from ..extras.constants import CHOICES
+from ..data import Role
if TYPE_CHECKING:
from datasets import Dataset
@@ -28,23 +29,26 @@ class EvalTemplate:
support_set: "Dataset",
subject_name: str,
use_history: bool
- ) -> Tuple[str, str, List[Tuple[str, str]]]:
- query, resp = self.parse_example(target_data)
- history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
+ ) -> List[Dict[str, str]]:
+ messages = []
+ for k in range(len(support_set)):
+ prompt, response = self.parse_example(support_set[k])
+ messages.append({"role": Role.USER, "content": prompt})
+ messages.append({"role": Role.ASSISTANT, "content": response})
- if len(history):
- temp = history.pop(0)
- history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
- else:
- query = self.system.format(subject=subject_name) + query
+ prompt, response = self.parse_example(target_data)
+ messages.append({"role": Role.USER, "content": prompt})
+ messages.append({"role": Role.ASSISTANT, "content": response})
+
+ messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
if not use_history:
- query = "\n\n".join(["".join(item) for item in history] + [query])
- history = []
- return query.strip(), resp, history
+ messages = [{"role": Role.USER, "content": "\n\n".join([message["content"] for message in messages[:-1]])}]
+
+ return messages
-eval_templates: Dict[str, EvalTemplate] = {}
+eval_templates: Dict[str, "EvalTemplate"] = {}
def register_eval_template(
@@ -62,7 +66,7 @@ def register_eval_template(
)
-def get_eval_template(name: str) -> EvalTemplate:
+def get_eval_template(name: str) -> "EvalTemplate":
eval_template = eval_templates.get(name, None)
assert eval_template is not None, "Template {} does not exist.".format(name)
return eval_template
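
With this change `format_example` returns an alternating list of user/assistant messages and prepends the subject header to the first user turn. A sketch of the resulting structure for a one-shot query, with plain strings standing in for the `Role` enum and a placeholder header (both illustrative):

```python
support_set = [("1 + 1 = ?", "A")]  # few-shot (prompt, answer) pairs
target = ("2 + 2 = ?", "B")         # the question being scored
subject_header = "The following are questions about {subject}.\n\n"  # placeholder text

messages = []
for prompt, response in support_set:
    messages.append({"role": "user", "content": prompt})
    messages.append({"role": "assistant", "content": response})
messages.append({"role": "user", "content": target[0]})
messages.append({"role": "assistant", "content": target[1]})

# As in format_example, the subject header is prepended to the first user turn.
messages[0]["content"] = subject_header.format(subject="math") + messages[0]["content"]

for message in messages:
    print(message)
```
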
diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py
index 17ab5dc1..b97d0168 100644
--- a/src/llmtuner/extras/callbacks.py
+++ b/src/llmtuner/extras/callbacks.py
@@ -6,9 +6,9 @@ from datetime import timedelta
from transformers import TrainerCallback
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
-from llmtuner.extras.constants import LOG_FILE_NAME
-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import fix_valuehead_checkpoint
+from .constants import LOG_FILE_NAME
+from .logging import get_logger
+from .misc import fix_valuehead_checkpoint
if TYPE_CHECKING:
diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py
index 546e3d5f..2c9fac93 100644
--- a/src/llmtuner/extras/constants.py
+++ b/src/llmtuner/extras/constants.py
@@ -5,6 +5,8 @@ from typing import Dict, Optional
CHOICES = ["A", "B", "C", "D"]
+DATA_CONFIG = "dataset_info.json"
+
DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str)
@@ -339,6 +341,30 @@ register_model_group(
)
+register_model_group(
+ models={
+ "InternLM2-7B": {
+ DownloadSource.DEFAULT: "internlm/internlm2-7b",
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b"
+ },
+ "InternLM2-20B": {
+ DownloadSource.DEFAULT: "internlm/internlm2-20b",
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b"
+ },
+ "InternLM2-7B-Chat": {
+ DownloadSource.DEFAULT: "internlm/internlm2-chat-7b",
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b"
+ },
+ "InternLM2-20B-Chat": {
+ DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b"
+ }
+ },
+ module="wqkv",
+ template="intern2"
+)
+
+
register_model_group(
models={
"LingoWhale-8B": {
diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index dee101ec..2a1199e4 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -13,8 +13,8 @@ from transformers.utils import (
)
from peft import PeftModel
-from llmtuner.extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
-from llmtuner.extras.logging import get_logger
+from .constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
+from .logging import get_logger
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
diff --git a/src/llmtuner/extras/patches/llama_patch.py b/src/llmtuner/extras/patches/llama_patch.py
index 1fb7ed3b..a9f5da28 100644
--- a/src/llmtuner/extras/patches/llama_patch.py
+++ b/src/llmtuner/extras/patches/llama_patch.py
@@ -10,7 +10,7 @@ try:
except ImportError:
print("Please upgrade `transformers`.")
-from llmtuner.extras.packages import is_flash_attn2_available
+from ..packages import is_flash_attn2_available
if is_flash_attn2_available():
diff --git a/src/llmtuner/extras/ploting.py b/src/llmtuner/extras/ploting.py
index cf2c72ac..65b3bf42 100644
--- a/src/llmtuner/extras/ploting.py
+++ b/src/llmtuner/extras/ploting.py
@@ -4,8 +4,8 @@ import json
from typing import List, Optional
from transformers.trainer import TRAINER_STATE_NAME
-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.packages import is_matplotlib_available
+from .logging import get_logger
+from .packages import is_matplotlib_available
if is_matplotlib_available():
import matplotlib.pyplot as plt
diff --git a/src/llmtuner/hparams/__init__.py b/src/llmtuner/hparams/__init__.py
index 623d6517..80deeb72 100644
--- a/src/llmtuner/hparams/__init__.py
+++ b/src/llmtuner/hparams/__init__.py
@@ -3,3 +3,16 @@ from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
+from .parser import get_train_args, get_infer_args, get_eval_args
+
+
+__all__ = [
+ "DataArguments",
+ "EvaluationArguments",
+ "FinetuningArguments",
+ "GeneratingArguments",
+ "ModelArguments",
+ "get_train_args",
+ "get_infer_args",
+ "get_eval_args"
+]
diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py
index 7be4f4f5..a635e47a 100644
--- a/src/llmtuner/hparams/data_args.py
+++ b/src/llmtuner/hparams/data_args.py
@@ -1,40 +1,7 @@
-import os
-import json
-from typing import List, Literal, Optional
+from typing import Literal, Optional
from dataclasses import dataclass, field
-DATA_CONFIG = "dataset_info.json"
-
-
-def use_modelscope() -> bool:
- return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
-
-
-@dataclass
-class DatasetAttr:
-
- load_from: Literal["hf_hub", "ms_hub", "script", "file"]
- dataset_name: Optional[str] = None
- dataset_sha1: Optional[str] = None
- subset: Optional[str] = None
- folder: Optional[str] = None
- ranking: Optional[bool] = False
- formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
-
- prompt: Optional[str] = "instruction"
- query: Optional[str] = "input"
- response: Optional[str] = "output"
- history: Optional[str] = None
- messages: Optional[str] = "conversations"
- role: Optional[str] = "from"
- content: Optional[str] = "value"
- system: Optional[str] = None
-
- def __repr__(self) -> str:
- return self.dataset_name
-
-
@dataclass
class DataArguments:
r"""
@@ -126,64 +93,3 @@ class DataArguments:
if self.streaming and self.max_samples is not None:
raise ValueError("`max_samples` is incompatible with `streaming`.")
-
- def init_for_training(self, seed: int): # support mixing multiple datasets
- self.seed = seed
- dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
- try:
- with open(os.path.join(self.dataset_dir, DATA_CONFIG), "r") as f:
- dataset_info = json.load(f)
- except Exception as err:
- if self.dataset is not None:
- raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err)))
- dataset_info = None
-
- if self.interleave_probs is not None:
- self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
-
- self.dataset_list: List[DatasetAttr] = []
- for name in dataset_names:
- if name not in dataset_info:
- raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
-
- has_hf_url = "hf_hub_url" in dataset_info[name]
- has_ms_url = "ms_hub_url" in dataset_info[name]
-
- if has_hf_url or has_ms_url:
- if (use_modelscope() and has_ms_url) or (not has_hf_url):
- dataset_attr = DatasetAttr(
- "ms_hub",
- dataset_name=dataset_info[name]["ms_hub_url"]
- )
- else:
- dataset_attr = DatasetAttr(
- "hf_hub",
- dataset_name=dataset_info[name]["hf_hub_url"]
- )
- elif "script_url" in dataset_info[name]:
- dataset_attr = DatasetAttr(
- "script",
- dataset_name=dataset_info[name]["script_url"]
- )
- else:
- dataset_attr = DatasetAttr(
- "file",
- dataset_name=dataset_info[name]["file_name"],
- dataset_sha1=dataset_info[name].get("file_sha1", None)
- )
-
- if "columns" in dataset_info[name]:
- dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
- dataset_attr.query = dataset_info[name]["columns"].get("query", None)
- dataset_attr.response = dataset_info[name]["columns"].get("response", None)
- dataset_attr.history = dataset_info[name]["columns"].get("history", None)
- dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
- dataset_attr.role = dataset_info[name]["columns"].get("role", None)
- dataset_attr.content = dataset_info[name]["columns"].get("content", None)
- dataset_attr.system = dataset_info[name]["columns"].get("system", None)
-
- dataset_attr.subset = dataset_info[name].get("subset", None)
- dataset_attr.folder = dataset_info[name].get("folder", None)
- dataset_attr.ranking = dataset_info[name].get("ranking", False)
- dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
- self.dataset_list.append(dataset_attr)
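
The `DatasetAttr` resolution removed here (hub URL vs. script vs. local file, chosen per entry in `dataset_info.json`) moves out of `DataArguments`; its new home is not part of the hunks shown. A condensed sketch of the same selection logic, kept only for illustration:

```python
import os
from typing import Any, Dict


def use_modelscope() -> bool:
    # Same switch as the helper removed above.
    return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))


def resolve_source(info: Dict[str, Any]) -> Dict[str, str]:
    # Pick where to load a dataset from, mirroring the removed init_for_training logic.
    if "hf_hub_url" in info or "ms_hub_url" in info:
        if use_modelscope() and "ms_hub_url" in info:
            return {"load_from": "ms_hub", "dataset_name": info["ms_hub_url"]}
        if "hf_hub_url" in info:
            return {"load_from": "hf_hub", "dataset_name": info["hf_hub_url"]}
        return {"load_from": "ms_hub", "dataset_name": info["ms_hub_url"]}
    if "script_url" in info:
        return {"load_from": "script", "dataset_name": info["script_url"]}
    return {"load_from": "file", "dataset_name": info["file_name"]}


# e.g. a hub-hosted entry (the exact dataset_info.json keys for this dataset are assumed):
print(resolve_source({"hf_hub_url": "glaiveai/glaive-function-calling-v2", "formatting": "sharegpt"}))
```
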
diff --git a/src/llmtuner/hparams/evaluation_args.py b/src/llmtuner/hparams/evaluation_args.py
index 5f507698..c70103ed 100644
--- a/src/llmtuner/hparams/evaluation_args.py
+++ b/src/llmtuner/hparams/evaluation_args.py
@@ -43,13 +43,5 @@ class EvaluationArguments:
)
def __post_init__(self):
- task_available = []
- for folder in os.listdir(self.task_dir):
- if os.path.isdir(os.path.join(self.task_dir, folder)):
- task_available.append(folder)
-
- if self.task not in task_available:
- raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir))
-
if self.save_dir is not None and os.path.exists(self.save_dir):
raise ValueError("`save_dir` already exists, use another one.")
diff --git a/src/llmtuner/model/parser.py b/src/llmtuner/hparams/parser.py
similarity index 97%
rename from src/llmtuner/model/parser.py
rename to src/llmtuner/hparams/parser.py
index f3626f69..cba9c690 100644
--- a/src/llmtuner/model/parser.py
+++ b/src/llmtuner/hparams/parser.py
@@ -8,14 +8,12 @@ from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
-from llmtuner.extras.logging import get_logger
-from llmtuner.hparams import (
- ModelArguments,
- DataArguments,
- EvaluationArguments,
- FinetuningArguments,
- GeneratingArguments
-)
+from ..extras.logging import get_logger
+from .data_args import DataArguments
+from .evaluation_args import EvaluationArguments
+from .finetuning_args import FinetuningArguments
+from .generating_args import GeneratingArguments
+from .model_args import ModelArguments
logger = get_logger(__name__)
@@ -107,8 +105,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
_set_transformers_logging()
# Check arguments
- data_args.init_for_training(training_args.seed)
-
if finetuning_args.stage != "pt" and data_args.template is None:
raise ValueError("Please specify which `template` to use.")
diff --git a/src/llmtuner/model/__init__.py b/src/llmtuner/model/__init__.py
index f12acb58..6d598361 100644
--- a/src/llmtuner/model/__init__.py
+++ b/src/llmtuner/model/__init__.py
@@ -1,5 +1,5 @@
-# Level: loader > adapter > parser, utils
+from .loader import load_model_and_tokenizer
+from .utils import dispatch_model, get_modelcard_args, load_valuehead_params
-from llmtuner.model.loader import load_model_and_tokenizer
-from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
-from llmtuner.model.utils import dispatch_model, get_modelcard_args, load_valuehead_params
+
+__all__ = ["load_model_and_tokenizer", "dispatch_model", "get_modelcard_args", "load_valuehead_params"]
diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py
index 261650b7..f0d7ce21 100644
--- a/src/llmtuner/model/adapter.py
+++ b/src/llmtuner/model/adapter.py
@@ -3,12 +3,12 @@ from typing import TYPE_CHECKING
from transformers.integrations import is_deepspeed_zero3_enabled
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
-from llmtuner.extras.logging import get_logger
-from llmtuner.model.utils import find_all_linear_modules
+from ..extras.logging import get_logger
+from .utils import find_all_linear_modules
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
- from llmtuner.hparams import ModelArguments, FinetuningArguments
+ from ..hparams import ModelArguments, FinetuningArguments
logger = get_logger(__name__)
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 8cdf85bf..adc45ea8 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -4,15 +4,15 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
from trl import AutoModelForCausalLMWithValueHead
-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import count_parameters, get_current_device, try_download_model_from_ms
-from llmtuner.model.adapter import init_adapter
-from llmtuner.model.patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
-from llmtuner.model.utils import load_valuehead_params, register_autoclass
+from ..extras.logging import get_logger
+from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
+from .adapter import init_adapter
+from .patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
+from .utils import load_valuehead_params, register_autoclass
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
- from llmtuner.hparams import ModelArguments, FinetuningArguments
+ from ..hparams import ModelArguments, FinetuningArguments
logger = get_logger(__name__)
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 381436d2..d21d87dc 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -10,15 +10,15 @@ from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTra
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
-from llmtuner.extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import get_current_device, infer_optim_dtype
-from llmtuner.extras.packages import is_flash_attn2_available
+from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
+from ..extras.logging import get_logger
+from ..extras.misc import get_current_device, infer_optim_dtype
+from ..extras.packages import is_flash_attn2_available
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from trl import AutoModelForCausalLMWithValueHead
- from llmtuner.hparams import ModelArguments
+ from ..hparams import ModelArguments
logger = get_logger(__name__)
diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py
index 14bd4c59..ba4478fb 100644
--- a/src/llmtuner/model/utils.py
+++ b/src/llmtuner/model/utils.py
@@ -4,13 +4,13 @@ from typing import TYPE_CHECKING, Any, Dict, List
from transformers import PreTrainedModel
from transformers.utils import cached_file
-from llmtuner.extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import get_current_device
+from ..extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
+from ..extras.logging import get_logger
+from ..extras.misc import get_current_device
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
- from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+ from ..hparams import ModelArguments, DataArguments, FinetuningArguments
logger = get_logger(__name__)
diff --git a/src/llmtuner/train/__init__.py b/src/llmtuner/train/__init__.py
index e57c163b..6c22bc15 100644
--- a/src/llmtuner/train/__init__.py
+++ b/src/llmtuner/train/__init__.py
@@ -1 +1,4 @@
-from llmtuner.train.tuner import export_model, run_exp
+from .tuner import export_model, run_exp
+
+
+__all__ = ["export_model", "run_exp"]
diff --git a/src/llmtuner/train/dpo/__init__.py b/src/llmtuner/train/dpo/__init__.py
index 96c8ed09..43fe9420 100644
--- a/src/llmtuner/train/dpo/__init__.py
+++ b/src/llmtuner/train/dpo/__init__.py
@@ -1 +1,4 @@
-from llmtuner.train.dpo.workflow import run_dpo
+from .workflow import run_dpo
+
+
+__all__ = ["run_dpo"]
diff --git a/src/llmtuner/train/dpo/trainer.py b/src/llmtuner/train/dpo/trainer.py
index 97d80353..b5a44f5e 100644
--- a/src/llmtuner/train/dpo/trainer.py
+++ b/src/llmtuner/train/dpo/trainer.py
@@ -5,7 +5,7 @@ from transformers import BatchEncoding, Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
-from llmtuner.extras.constants import IGNORE_INDEX
+from ...extras.constants import IGNORE_INDEX
if TYPE_CHECKING:
from transformers import PreTrainedModel
diff --git a/src/llmtuner/train/dpo/workflow.py b/src/llmtuner/train/dpo/workflow.py
index 12a6b545..bd61a308 100644
--- a/src/llmtuner/train/dpo/workflow.py
+++ b/src/llmtuner/train/dpo/workflow.py
@@ -3,18 +3,18 @@
from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
-from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.hparams import ModelArguments
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
-from llmtuner.train.dpo.trainer import CustomDPOTrainer
-from llmtuner.train.utils import create_modelcard_and_push, create_ref_model
+from ...data import get_dataset, split_dataset
+from ...extras.constants import IGNORE_INDEX
+from ...extras.ploting import plot_loss
+from ...hparams import ModelArguments
+from ...model import load_model_and_tokenizer
+from ...train.dpo.collator import DPODataCollatorWithPadding
+from ...train.dpo.trainer import CustomDPOTrainer
+from ...train.utils import create_modelcard_and_push, create_ref_model
if TYPE_CHECKING:
from transformers import TrainerCallback
- from llmtuner.hparams import DataArguments, FinetuningArguments
+ from ...hparams import DataArguments, FinetuningArguments
def run_dpo(
@@ -24,9 +24,8 @@ def run_dpo(
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
- dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
- dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
+ dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="rm")
data_collator = DPODataCollatorWithPadding(
tokenizer=tokenizer,
pad_to_multiple_of=8,
diff --git a/src/llmtuner/train/ppo/__init__.py b/src/llmtuner/train/ppo/__init__.py
index c32b23fa..d17336d5 100644
--- a/src/llmtuner/train/ppo/__init__.py
+++ b/src/llmtuner/train/ppo/__init__.py
@@ -1 +1,4 @@
-from llmtuner.train.ppo.workflow import run_ppo
+from .workflow import run_ppo
+
+
+__all__ = ["run_ppo"]
diff --git a/src/llmtuner/train/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py
index 31cab7c0..8b2116ea 100644
--- a/src/llmtuner/train/ppo/trainer.py
+++ b/src/llmtuner/train/ppo/trainer.py
@@ -13,15 +13,15 @@ from transformers.trainer_pt_utils import remove_dummy_checkpoint
from trl import PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
-from llmtuner.extras.callbacks import LogCallback, FixValueHeadModelCallback
-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
-from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
+from ...extras.callbacks import LogCallback, FixValueHeadModelCallback
+from ...extras.logging import get_logger
+from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
+from .utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead
- from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments
+ from ...hparams import ModelArguments, FinetuningArguments, GeneratingArguments
logger = get_logger(__name__)
diff --git a/src/llmtuner/train/ppo/utils.py b/src/llmtuner/train/ppo/utils.py
index 12e9bfcb..44e62067 100644
--- a/src/llmtuner/train/ppo/utils.py
+++ b/src/llmtuner/train/ppo/utils.py
@@ -2,7 +2,7 @@ import json
import torch
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
-from llmtuner.extras.packages import is_requests_available
+from ...extras.packages import is_requests_available
if TYPE_CHECKING:
from transformers import PreTrainedModel
diff --git a/src/llmtuner/train/ppo/workflow.py b/src/llmtuner/train/ppo/workflow.py
index 10c6a227..7b0dcc53 100644
--- a/src/llmtuner/train/ppo/workflow.py
+++ b/src/llmtuner/train/ppo/workflow.py
@@ -7,17 +7,17 @@ from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorWithPadding
from transformers.optimization import get_scheduler
-from llmtuner.data import get_dataset, preprocess_dataset
-from llmtuner.extras.callbacks import FixValueHeadModelCallback
-from llmtuner.extras.misc import fix_valuehead_checkpoint
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.utils import create_ref_model, create_reward_model
-from llmtuner.train.ppo.trainer import CustomPPOTrainer
+from ...data import get_dataset
+from ...extras.callbacks import FixValueHeadModelCallback
+from ...extras.misc import fix_valuehead_checkpoint
+from ...extras.ploting import plot_loss
+from ...model import load_model_and_tokenizer
+from ...train.utils import create_ref_model, create_reward_model
+from ...train.ppo.trainer import CustomPPOTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
- from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+ from ...hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
def run_ppo(
@@ -28,9 +28,8 @@ def run_ppo(
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
- dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
- dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
+ dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="ppo")
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
diff --git a/src/llmtuner/train/pt/__init__.py b/src/llmtuner/train/pt/__init__.py
index eacbeadb..bdf397f6 100644
--- a/src/llmtuner/train/pt/__init__.py
+++ b/src/llmtuner/train/pt/__init__.py
@@ -1 +1,4 @@
-from llmtuner.train.pt.workflow import run_pt
+from .workflow import run_pt
+
+
+__all__ = ["run_pt"]
diff --git a/src/llmtuner/train/pt/workflow.py b/src/llmtuner/train/pt/workflow.py
index 27a6d2c4..3b7267eb 100644
--- a/src/llmtuner/train/pt/workflow.py
+++ b/src/llmtuner/train/pt/workflow.py
@@ -4,14 +4,14 @@ import math
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForLanguageModeling, Trainer
-from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.utils import create_modelcard_and_push
+from ...data import get_dataset, split_dataset
+from ...extras.ploting import plot_loss
+from ...model import load_model_and_tokenizer
+from ...train.utils import create_modelcard_and_push
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
- from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+ from ...hparams import ModelArguments, DataArguments, FinetuningArguments
def run_pt(
@@ -21,9 +21,8 @@ def run_pt(
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
- dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
- dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
+ dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="pt")
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Initialize our Trainer
diff --git a/src/llmtuner/train/rm/__init__.py b/src/llmtuner/train/rm/__init__.py
index c80ccfb9..dedac35f 100644
--- a/src/llmtuner/train/rm/__init__.py
+++ b/src/llmtuner/train/rm/__init__.py
@@ -1 +1,4 @@
-from llmtuner.train.rm.workflow import run_rm
+from .workflow import run_rm
+
+
+__all__ = ["run_rm"]
diff --git a/src/llmtuner/train/rm/trainer.py b/src/llmtuner/train/rm/trainer.py
index b018a8c4..909d4373 100644
--- a/src/llmtuner/train/rm/trainer.py
+++ b/src/llmtuner/train/rm/trainer.py
@@ -4,7 +4,7 @@ import torch
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from transformers import Trainer
-from llmtuner.extras.logging import get_logger
+from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
diff --git a/src/llmtuner/train/rm/workflow.py b/src/llmtuner/train/rm/workflow.py
index 52070027..e055e216 100644
--- a/src/llmtuner/train/rm/workflow.py
+++ b/src/llmtuner/train/rm/workflow.py
@@ -3,19 +3,19 @@
from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
-from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.callbacks import FixValueHeadModelCallback
-from llmtuner.extras.misc import fix_valuehead_checkpoint
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
-from llmtuner.train.rm.metric import compute_accuracy
-from llmtuner.train.rm.trainer import PairwiseTrainer
-from llmtuner.train.utils import create_modelcard_and_push
+from ...data import get_dataset, split_dataset
+from ...extras.callbacks import FixValueHeadModelCallback
+from ...extras.misc import fix_valuehead_checkpoint
+from ...extras.ploting import plot_loss
+from ...model import load_model_and_tokenizer
+from ...train.rm.collator import PairwiseDataCollatorWithPadding
+from ...train.rm.metric import compute_accuracy
+from ...train.rm.trainer import PairwiseTrainer
+from ...train.utils import create_modelcard_and_push
if TYPE_CHECKING:
from transformers import TrainerCallback
- from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+ from ...hparams import ModelArguments, DataArguments, FinetuningArguments
def run_rm(
@@ -25,9 +25,8 @@ def run_rm(
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
- dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
- dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
+ dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="rm")
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
# Update arguments
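
`PairwiseDataCollatorWithPadding` (in `src/llmtuner/train/rm/collator.py`) batches chosen/rejected pairs for the reward model, and `pad_to_multiple_of=8` rounds sequence lengths up so tensor cores see aligned shapes. A plausible toy stand-in for the pairing idea, assuming a chosen-before-rejected layout; this is not the project's actual collator:

```python
import torch

def pairwise_collate(features, pad_id=0, pad_to_multiple_of=8):
    # Assumed layout: all chosen sequences first, then all rejected ones.
    seqs = [f["chosen_ids"] for f in features] + [f["rejected_ids"] for f in features]
    max_len = max(len(s) for s in seqs)
    max_len += -max_len % pad_to_multiple_of  # round up to a multiple of 8
    batch = torch.full((len(seqs), max_len), pad_id, dtype=torch.long)
    for i, s in enumerate(seqs):
        batch[i, : len(s)] = torch.tensor(s)
    return {"input_ids": batch}

features = [{"chosen_ids": [1, 2, 3], "rejected_ids": [1, 2, 4, 5]}]
print(pairwise_collate(features)["input_ids"].shape)  # torch.Size([2, 8])
```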
diff --git a/src/llmtuner/train/sft/__init__.py b/src/llmtuner/train/sft/__init__.py
index cb5448f4..f2f84e78 100644
--- a/src/llmtuner/train/sft/__init__.py
+++ b/src/llmtuner/train/sft/__init__.py
@@ -1 +1,4 @@
-from llmtuner.train.sft.workflow import run_sft
+from .workflow import run_sft
+
+
+__all__ = ["run_sft"]
diff --git a/src/llmtuner/train/sft/metric.py b/src/llmtuner/train/sft/metric.py
index 18db0b88..2741c66b 100644
--- a/src/llmtuner/train/sft/metric.py
+++ b/src/llmtuner/train/sft/metric.py
@@ -2,8 +2,8 @@ import numpy as np
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
-from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.packages import (
+from ...extras.constants import IGNORE_INDEX
+from ...extras.packages import (
is_jieba_available, is_nltk_available, is_rouge_available
)
diff --git a/src/llmtuner/train/sft/trainer.py b/src/llmtuner/train/sft/trainer.py
index 291bbc7a..c8d9f039 100644
--- a/src/llmtuner/train/sft/trainer.py
+++ b/src/llmtuner/train/sft/trainer.py
@@ -6,8 +6,8 @@ import torch.nn as nn
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from transformers import Seq2SeqTrainer
-from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.logging import get_logger
+from ...extras.constants import IGNORE_INDEX
+from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
diff --git a/src/llmtuner/train/sft/workflow.py b/src/llmtuner/train/sft/workflow.py
index 0e9bf7e4..6d3f34e8 100644
--- a/src/llmtuner/train/sft/workflow.py
+++ b/src/llmtuner/train/sft/workflow.py
@@ -3,18 +3,19 @@
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
-from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.misc import get_logits_processor
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.sft.metric import ComputeMetrics
-from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer
-from llmtuner.train.utils import create_modelcard_and_push
+from ...data import get_dataset, split_dataset
+from ...extras.constants import IGNORE_INDEX
+from ...extras.misc import get_logits_processor
+from ...extras.ploting import plot_loss
+from ...model import load_model_and_tokenizer
+from ...train.sft.metric import ComputeMetrics
+from ...train.sft.trainer import CustomSeq2SeqTrainer
+from ...train.utils import create_modelcard_and_push
+
if TYPE_CHECKING:
from transformers import TrainerCallback
- from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+ from ...hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
def run_sft(
@@ -25,9 +26,8 @@ def run_sft(
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
- dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
- dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
+ dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="sft")
if training_args.predict_with_generate:
tokenizer.padding_side = "left" # use left-padding in generation
diff --git a/src/llmtuner/train/tuner.py b/src/llmtuner/train/tuner.py
index 8705c98e..32f1cda0 100644
--- a/src/llmtuner/train/tuner.py
+++ b/src/llmtuner/train/tuner.py
@@ -2,14 +2,15 @@ import torch
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from transformers import PreTrainedModel
-from llmtuner.extras.callbacks import LogCallback
-from llmtuner.extras.logging import get_logger
-from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer
-from llmtuner.train.pt import run_pt
-from llmtuner.train.sft import run_sft
-from llmtuner.train.rm import run_rm
-from llmtuner.train.ppo import run_ppo
-from llmtuner.train.dpo import run_dpo
+from ..extras.callbacks import LogCallback
+from ..extras.logging import get_logger
+from ..hparams import get_train_args, get_infer_args
+from ..model import load_model_and_tokenizer
+from .pt import run_pt
+from .sft import run_sft
+from .rm import run_rm
+from .ppo import run_ppo
+from .dpo import run_dpo
if TYPE_CHECKING:
from transformers import TrainerCallback
diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py
index 4cc775eb..789986e4 100644
--- a/src/llmtuner/train/utils.py
+++ b/src/llmtuner/train/utils.py
@@ -1,15 +1,15 @@
import torch
from typing import TYPE_CHECKING, Optional, Union
-from llmtuner.extras.logging import get_logger
-from llmtuner.hparams import ModelArguments, FinetuningArguments
-from llmtuner.model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params
+from ..extras.logging import get_logger
+from ..hparams import ModelArguments, FinetuningArguments
+from ..model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, Trainer
from transformers.modeling_utils import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead
- from llmtuner.hparams import DataArguments
+ from ..hparams import DataArguments
logger = get_logger(__name__)
diff --git a/src/llmtuner/webui/__init__.py b/src/llmtuner/webui/__init__.py
index a27c7f6e..3e82dd69 100644
--- a/src/llmtuner/webui/__init__.py
+++ b/src/llmtuner/webui/__init__.py
@@ -1 +1,4 @@
-from llmtuner.webui.interface import create_ui, create_web_demo
+from .interface import create_ui, create_web_demo
+
+
+__all__ = ["create_ui", "create_web_demo"]
diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py
index 08027e38..e211bb2a 100644
--- a/src/llmtuner/webui/chatter.py
+++ b/src/llmtuner/webui/chatter.py
@@ -2,14 +2,14 @@ import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
-from llmtuner.chat import ChatModel
-from llmtuner.extras.misc import torch_gc
-from llmtuner.hparams import GeneratingArguments
-from llmtuner.webui.common import get_save_dir
-from llmtuner.webui.locales import ALERTS
+from ..chat import ChatModel
+from ..extras.misc import torch_gc
+from ..hparams import GeneratingArguments
+from .common import get_save_dir
+from .locales import ALERTS
if TYPE_CHECKING:
- from llmtuner.webui.manager import Manager
+ from .manager import Manager
class WebChatModel(ChatModel):
@@ -105,6 +105,7 @@ class WebChatModel(ChatModel):
query: str,
history: List[Tuple[str, str]],
system: str,
+ tools: str,
max_new_tokens: int,
top_p: float,
temperature: float
@@ -112,7 +113,7 @@ class WebChatModel(ChatModel):
chatbot.append([query, ""])
response = ""
for new_text in self.stream_chat(
- query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+ query, history, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
):
response += new_text
new_history = history + [(query, response)]
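
The chatter change above threads the new `tools` textbox value into `ChatModel.stream_chat` as a fourth positional argument, mirroring the `system` prompt. A hedged usage sketch: the constructor arguments and the tools JSON shape are illustrative assumptions, and only the positional call order is taken from the diff:

```python
# Sketch only: model path, template and the tools payload are assumptions,
# not the project's documented defaults.
from llmtuner.chat import ChatModel

chat_model = ChatModel(dict(
    model_name_or_path="meta-llama/Llama-2-7b-chat-hf",  # assumed example model
    template="llama2",
))
tools = '[{"name": "get_weather", "description": "Query the weather", "parameters": {}}]'

# Positional order matches the diff: query, history, system, tools.
for new_text in chat_model.stream_chat("How is the weather in Paris?", [], "", tools,
                                       max_new_tokens=128, top_p=0.7, temperature=0.95):
    print(new_text, end="", flush=True)
```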
diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py
index 28d8a805..3d431aeb 100644
--- a/src/llmtuner/webui/common.py
+++ b/src/llmtuner/webui/common.py
@@ -5,7 +5,8 @@ from collections import defaultdict
from typing import Any, Dict, Optional
from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
-from llmtuner.extras.constants import (
+from ..extras.constants import (
+ DATA_CONFIG,
DEFAULT_MODULE,
DEFAULT_TEMPLATE,
PEFT_METHODS,
@@ -13,8 +14,7 @@ from llmtuner.extras.constants import (
TRAINING_STAGES,
DownloadSource
)
-from llmtuner.extras.misc import use_modelscope
-from llmtuner.hparams.data_args import DATA_CONFIG
+from ..extras.misc import use_modelscope
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
diff --git a/src/llmtuner/webui/components/__init__.py b/src/llmtuner/webui/components/__init__.py
index 32228b8e..2e9a87ec 100644
--- a/src/llmtuner/webui/components/__init__.py
+++ b/src/llmtuner/webui/components/__init__.py
@@ -1,6 +1,11 @@
-from llmtuner.webui.components.top import create_top
-from llmtuner.webui.components.train import create_train_tab
-from llmtuner.webui.components.eval import create_eval_tab
-from llmtuner.webui.components.infer import create_infer_tab
-from llmtuner.webui.components.export import create_export_tab
-from llmtuner.webui.components.chatbot import create_chat_box
+from .top import create_top
+from .train import create_train_tab
+from .eval import create_eval_tab
+from .infer import create_infer_tab
+from .export import create_export_tab
+from .chatbot import create_chat_box
+
+
+__all__ = [
+ "create_top", "create_train_tab", "create_eval_tab", "create_infer_tab", "create_export_tab", "create_chat_box"
+]
diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py
index 13e2dd4d..ebc1b71f 100644
--- a/src/llmtuner/webui/components/chatbot.py
+++ b/src/llmtuner/webui/components/chatbot.py
@@ -1,10 +1,14 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict, Optional, Tuple
+from ..utils import check_json_schema
+
+
if TYPE_CHECKING:
from gradio.blocks import Block
from gradio.components import Component
- from llmtuner.webui.engine import Engine
+
+ from ..engine import Engine
def create_chat_box(
@@ -17,6 +21,7 @@ def create_chat_box(
with gr.Row():
with gr.Column(scale=4):
system = gr.Textbox(show_label=False)
+ tools = gr.Textbox(show_label=False, lines=2)
query = gr.Textbox(show_label=False, lines=8)
submit_btn = gr.Button(variant="primary")
@@ -27,9 +32,11 @@ def create_chat_box(
top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
+ tools.input(check_json_schema, [tools])
+
submit_btn.click(
engine.chatter.predict,
- [chatbot, query, history, system, max_new_tokens, top_p, temperature],
+ [chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
[chatbot, history],
show_progress=True
).then(
@@ -40,6 +47,7 @@ def create_chat_box(
return chat_box, chatbot, history, dict(
system=system,
+ tools=tools,
query=query,
submit_btn=submit_btn,
clear_btn=clear_btn,
diff --git a/src/llmtuner/webui/components/data.py b/src/llmtuner/webui/components/data.py
index a74bd029..3a50065a 100644
--- a/src/llmtuner/webui/components/data.py
+++ b/src/llmtuner/webui/components/data.py
@@ -3,7 +3,7 @@ import json
import gradio as gr
from typing import TYPE_CHECKING, Any, Dict, Tuple
-from llmtuner.webui.common import DATA_CONFIG
+from ...extras.constants import DATA_CONFIG
if TYPE_CHECKING:
from gradio.components import Component
diff --git a/src/llmtuner/webui/components/eval.py b/src/llmtuner/webui/components/eval.py
index 0718c63e..d900ad29 100644
--- a/src/llmtuner/webui/components/eval.py
+++ b/src/llmtuner/webui/components/eval.py
@@ -1,12 +1,13 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
-from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
-from llmtuner.webui.components.data import create_preview_box
+from ..common import list_dataset, DEFAULT_DATA_DIR
+from .data import create_preview_box
if TYPE_CHECKING:
from gradio.components import Component
- from llmtuner.webui.engine import Engine
+
+ from ..engine import Engine
def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
diff --git a/src/llmtuner/webui/components/export.py b/src/llmtuner/webui/components/export.py
index a4b65591..187cda64 100644
--- a/src/llmtuner/webui/components/export.py
+++ b/src/llmtuner/webui/components/export.py
@@ -1,13 +1,14 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict, Generator, List
-from llmtuner.train import export_model
-from llmtuner.webui.common import get_save_dir
-from llmtuner.webui.locales import ALERTS
+from ...train import export_model
+from ..common import get_save_dir
+from ..locales import ALERTS
if TYPE_CHECKING:
from gradio.components import Component
- from llmtuner.webui.engine import Engine
+
+ from ..engine import Engine
GPTQ_BITS = ["8", "4", "3", "2"]
diff --git a/src/llmtuner/webui/components/infer.py b/src/llmtuner/webui/components/infer.py
index d6dd7eed..ba578f10 100644
--- a/src/llmtuner/webui/components/infer.py
+++ b/src/llmtuner/webui/components/infer.py
@@ -1,11 +1,12 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
-from llmtuner.webui.components.chatbot import create_chat_box
+from .chatbot import create_chat_box
if TYPE_CHECKING:
from gradio.components import Component
- from llmtuner.webui.engine import Engine
+
+ from ..engine import Engine
def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py
index 74441ab2..b8468186 100644
--- a/src/llmtuner/webui/components/top.py
+++ b/src/llmtuner/webui/components/top.py
@@ -1,10 +1,10 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
-from llmtuner.data.template import templates
-from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
-from llmtuner.webui.common import get_model_path, get_template, list_adapters, save_config
-from llmtuner.webui.utils import can_quantize
+from ...data import templates
+from ...extras.constants import METHODS, SUPPORTED_MODELS
+from ..common import get_model_path, get_template, list_adapters, save_config
+from ..utils import can_quantize
if TYPE_CHECKING:
from gradio.components import Component
diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py
index 5989c421..08e861f0 100644
--- a/src/llmtuner/webui/components/train.py
+++ b/src/llmtuner/webui/components/train.py
@@ -2,14 +2,15 @@ import gradio as gr
from typing import TYPE_CHECKING, Dict
from transformers.trainer_utils import SchedulerType
-from llmtuner.extras.constants import TRAINING_STAGES
-from llmtuner.webui.common import list_adapters, list_dataset, DEFAULT_DATA_DIR
-from llmtuner.webui.components.data import create_preview_box
-from llmtuner.webui.utils import gen_plot
+from ...extras.constants import TRAINING_STAGES
+from ..common import list_adapters, list_dataset, DEFAULT_DATA_DIR
+from ..components.data import create_preview_box
+from ..utils import gen_plot
if TYPE_CHECKING:
from gradio.components import Component
- from llmtuner.webui.engine import Engine
+
+ from ..engine import Engine
def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
diff --git a/src/llmtuner/webui/engine.py b/src/llmtuner/webui/engine.py
index 991b281c..db60b5df 100644
--- a/src/llmtuner/webui/engine.py
+++ b/src/llmtuner/webui/engine.py
@@ -2,12 +2,12 @@ import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import Any, Dict, Generator, Optional
-from llmtuner.webui.chatter import WebChatModel
-from llmtuner.webui.common import get_model_path, list_dataset, load_config
-from llmtuner.webui.locales import LOCALES
-from llmtuner.webui.manager import Manager
-from llmtuner.webui.runner import Runner
-from llmtuner.webui.utils import get_time
+from .chatter import WebChatModel
+from .common import get_model_path, list_dataset, load_config
+from .locales import LOCALES
+from .manager import Manager
+from .runner import Runner
+from .utils import get_time
class Engine:
diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py
index 2525c3fd..39ddca04 100644
--- a/src/llmtuner/webui/interface.py
+++ b/src/llmtuner/webui/interface.py
@@ -2,7 +2,7 @@ import gradio as gr
from typing import Optional
from transformers.utils.versions import require_version
-from llmtuner.webui.components import (
+from .components import (
create_top,
create_train_tab,
create_eval_tab,
@@ -10,9 +10,9 @@ from llmtuner.webui.components import (
create_export_tab,
create_chat_box
)
-from llmtuner.webui.common import save_config
-from llmtuner.webui.css import CSS
-from llmtuner.webui.engine import Engine
+from .common import save_config
+from .css import CSS
+from .engine import Engine
require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")
diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py
index d6f9d31f..9ba08e25 100644
--- a/src/llmtuner/webui/locales.py
+++ b/src/llmtuner/webui/locales.py
@@ -521,6 +521,14 @@ LOCALES = {
"placeholder": "系统提示词(非必填)"
}
},
+ "tools": {
+ "en": {
+ "placeholder": "Tools (optional)"
+ },
+ "zh": {
+ "placeholder": "工具列表(非必填)"
+ }
+ },
"query": {
"en": {
"placeholder": "Input..."
diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py
index 374d72a3..5d8efbfb 100644
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -9,17 +9,17 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
import transformers
from transformers.trainer import TRAINING_ARGS_NAME
-from llmtuner.extras.callbacks import LogCallback
-from llmtuner.extras.constants import TRAINING_STAGES
-from llmtuner.extras.logging import LoggerHandler
-from llmtuner.extras.misc import get_device_count, torch_gc
-from llmtuner.train import run_exp
-from llmtuner.webui.common import get_module, get_save_dir, load_config
-from llmtuner.webui.locales import ALERTS
-from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
+from ..extras.callbacks import LogCallback
+from ..extras.constants import TRAINING_STAGES
+from ..extras.logging import LoggerHandler
+from ..extras.misc import get_device_count, torch_gc
+from ..train import run_exp
+from .common import get_module, get_save_dir, load_config
+from .locales import ALERTS
+from .utils import gen_cmd, get_eval_results, update_process_bar
if TYPE_CHECKING:
- from llmtuner.webui.manager import Manager
+ from .manager import Manager
class Runner:
diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py
index 4579d296..6bd80093 100644
--- a/src/llmtuner/webui/utils.py
+++ b/src/llmtuner/webui/utils.py
@@ -4,12 +4,12 @@ import gradio as gr
from typing import TYPE_CHECKING, Any, Dict
from datetime import datetime
-from llmtuner.extras.packages import is_matplotlib_available
-from llmtuner.extras.ploting import smooth
-from llmtuner.webui.common import get_save_dir
+from ..extras.packages import is_matplotlib_available
+from ..extras.ploting import smooth
+from .common import get_save_dir
if TYPE_CHECKING:
- from llmtuner.extras.callbacks import LogCallback
+ from ..extras.callbacks import LogCallback
if is_matplotlib_available():
import matplotlib.figure
@@ -41,6 +41,13 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
return gr.update(interactive=True)
+def check_json_schema(text: str) -> None:
+ try:
+ json.loads(text)
+ except json.JSONDecodeError:
+ gr.Warning("Invalid JSON schema")
+
+
def gen_cmd(args: Dict[str, Any]) -> str:
args.pop("disable_tqdm", None)
args["plot_loss"] = args.get("do_train", None)
diff --git a/tests/llamafy_internlm2.py b/tests/llamafy_internlm2.py
index 3cd59f96..8fb1448c 100644
--- a/tests/llamafy_internlm2.py
+++ b/tests/llamafy_internlm2.py
@@ -1,6 +1,7 @@
# coding=utf-8
# Converts the InternLM2 model into the same format as LLaMA2.
# Usage: python llamafy_internlm2.py --input_dir input --output_dir output --shard_size 10GB
+# Warning: We have found that the converted model cannot run inference correctly. This will be fixed later.
import os
import fire
@@ -43,19 +44,18 @@ def save_weight(
llama2_state_dict[key.replace("output", "lm_head")] = value
elif "tok_embeddings" in key:
llama2_state_dict[key.replace("tok_embeddings", "embed_tokens")] = value
- elif "attention_norm" in key:
- llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value
elif "wqkv" in key:
- proj_size = value.size(0)
num_q_heads = internlm2_config_dict["num_attention_heads"]
num_kv_heads = internlm2_config_dict["num_key_value_heads"]
- q_size = proj_size // (num_q_heads + 2 * num_kv_heads) * num_q_heads
- kv_size = proj_size // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
+ q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads
+ kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...]
llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[q_size:q_size+kv_size, ...]
llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size+kv_size:, ...]
elif "wo" in key:
llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value
+ elif "attention_norm" in key:
+ llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value
elif "ffn_norm" in key:
llama2_state_dict[key.replace("ffn_norm", "post_attention_layernorm")] = value
elif "w1" in key: