Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-09-15 17:42:48 +08:00)

commit 4e3bfb799d (parent 71306bbfb1)
support function calling

Former-commit-id: d9f1cae35150cce594a7abd96dd2beb811fa33f2
data/dataset_info.json
@@ -165,9 +165,13 @@
     "hf_hub_url": "HuggingFaceH4/ultrachat_200k",
     "ms_hub_url": "AI-ModelScope/ultrachat_200k",
     "columns": {
-      "messages": "messages",
-      "role": "role",
-      "content": "content"
+      "messages": "messages"
+    },
+    "tags": {
+      "role_tag": "role",
+      "content_tag": "content",
+      "user_tag": "human",
+      "assistant_tag": "assistant"
     },
     "formatting": "sharegpt"
   },
@@ -180,9 +184,13 @@
     "hf_hub_url": "lmsys/lmsys-chat-1m",
     "ms_hub_url": "AI-ModelScope/lmsys-chat-1m",
     "columns": {
-      "messages": "conversation",
-      "role": "role",
-      "content": "content"
+      "messages": "conversation"
+    },
+    "tags": {
+      "role_tag": "role",
+      "content_tag": "content",
+      "user_tag": "human",
+      "assistant_tag": "assistant"
     },
     "formatting": "sharegpt"
   },
@@ -190,6 +198,14 @@
     "hf_hub_url": "WizardLM/WizardLM_evol_instruct_V2_196k",
     "formatting": "sharegpt"
   },
+  "glaive_func_call": {
+    "file_name": "glaive_func_call.json",
+    "formatting": "sharegpt",
+    "columns": {
+      "messages": "conversations",
+      "tool": "tools"
+    }
+  },
   "hh_rlhf_en": {
     "script_url": "hh_rlhf_en",
     "columns": {
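Note: the registry schema changes here. For sharegpt-formatted entries, "columns" now only maps dataset fields ("messages", and optionally "tool" and "system"), while the new "tags" block names both the per-message keys ("role_tag", "content_tag") and the accepted role values ("user_tag", "assistant_tag", plus "observation_tag" and "function_tag" for function calling). A minimal sketch of a custom entry under the new schema, written as an equivalent Python dict for illustration (the dataset name and file are hypothetical; the tag values shown are the parser defaults introduced later in this commit):

    # hypothetical dataset_info.json entry, shown as an equivalent Python dict
    my_tool_dataset = {
        "file_name": "my_tool_dataset.json",   # hypothetical local file under data/
        "formatting": "sharegpt",
        "columns": {"messages": "conversations", "tool": "tools"},
        "tags": {
            "role_tag": "from",                 # message key holding the speaker
            "content_tag": "value",             # message key holding the text
            "user_tag": "human",
            "assistant_tag": "gpt",
            "observation_tag": "observation",   # tool results fed back to the model
            "function_tag": "function_call",    # model-emitted tool invocations
        },
    }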
data/glaive_func_call.json (new file, 68 lines)
@@ -0,0 +1,68 @@
+[
+  {
+    "conversations": [
+      {
+        "from": "human",
+        "value": "I need a new password. Can you generate one for me?"
+      },
+      {
+        "from": "gpt",
+        "value": "Of course. How long would you like your password to be? And would you like it to include symbols?"
+      },
+      {
+        "from": "human",
+        "value": "I would like it to be 12 characters long and yes, please include symbols."
+      },
+      {
+        "from": "function_call",
+        "value": "{\"name\": \"generate_password\", \"arguments\": {\"length\": 12, \"include_symbols\": true}}"
+      },
+      {
+        "from": "observation",
+        "value": "{\"password\": \"4&7j#9@1Q6*\"}"
+      },
+      {
+        "from": "gpt",
+        "value": "Here is your new password: 4&7j#9@1Q6*. Please make sure to save it in a secure location."
+      }
+    ],
+    "tools": "[{\"name\": \"generate_password\", \"description\": \"Generate a random password\", \"parameters\": {\"type\": \"object\", \"properties\": {\"length\": {\"type\": \"integer\", \"description\": \"The length of the password\"}, \"include_symbols\": {\"type\": \"boolean\", \"description\": \"Whether to include symbols in the password\"}}, \"required\": [\"length\"]}}, {\"name\": \"create_task\", \"description\": \"Create a new task in a task management system\", \"parameters\": {\"type\": \"object\", \"properties\": {\"title\": {\"type\": \"string\", \"description\": \"The title of the task\"}, \"due_date\": {\"type\": \"string\", \"format\": \"date\", \"description\": \"The due date of the task\"}, \"priority\": {\"type\": \"string\", \"enum\": [\"low\", \"medium\", \"high\"], \"description\": \"The priority of the task\"}}, \"required\": [\"title\", \"due_date\", \"priority\"]}}]"
+  },
+  {
+    "conversations": [
+      {
+        "from": "human",
+        "value": "Can you tell me the latest news headlines for the United States?"
+      },
+      {
+        "from": "function_call",
+        "value": "{\"name\": \"get_news_headlines\", \"arguments\": {\"country\": \"United States\"}}"
+      },
+      {
+        "from": "observation",
+        "value": "{\"headlines\": [\"Biden announces new vaccine mandates\", \"Hurricane Ida devastates Louisiana\", \"Apple unveils new iPhone\", \"NASA's Perseverance rover collects first Mars rock sample\"]}"
+      },
+      {
+        "from": "gpt",
+        "value": "Here are the latest news headlines for the United States:"
+      },
+      {
+        "from": "human",
+        "value": "That's interesting. What about the news in France?"
+      },
+      {
+        "from": "function_call",
+        "value": "{\"name\": \"get_news_headlines\", \"arguments\": {\"country\": \"France\"}}"
+      },
+      {
+        "from": "observation",
+        "value": "{\"headlines\": [\"France recalls ambassadors to US and Australia\", \"French election: Macron's party braces for tough fight\", \"Louvre Museum to undergo major overhaul\", \"France to offer free birth control to all women under 25\"]}"
+      },
+      {
+        "from": "gpt",
+        "value": "Here are the latest news headlines for France:"
+      }
+    ],
+    "tools": "[{\"name\": \"get_news_headlines\", \"description\": \"Get the latest news headlines\", \"parameters\": {\"type\": \"object\", \"properties\": {\"country\": {\"type\": \"string\", \"description\": \"The country for which to fetch news\"}}, \"required\": [\"country\"]}}]"
+  }
+]
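Note: in this layout, user and observation turns occupy even positions and gpt and function_call turns occupy odd positions, and the "value" of a function_call or observation turn is itself a JSON string. A minimal sketch of reading one record with only the standard library (variable names are illustrative):

    import json

    with open("data/glaive_func_call.json") as f:
        record = json.load(f)[0]

    for message in record["conversations"]:
        if message["from"] == "function_call":
            call = json.loads(message["value"])   # the value is itself a JSON string
            print("tool call:", call["name"], call["arguments"])
        elif message["from"] == "observation":
            print("tool result:", json.loads(message["value"]))

    tools = json.loads(record["tools"])           # JSON-schema-style tool list
    print("available tools:", [tool["name"] for tool in tools])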
requirements.txt
@@ -9,7 +9,6 @@ scipy
 einops
 sentencepiece
 protobuf
-tiktoken
 jieba
 rouge-chinese
 nltk
src/llmtuner/__init__.py
@@ -8,3 +8,12 @@ from llmtuner.webui import create_ui, create_web_demo
 
 
 __version__ = "0.4.0"
+__all__ = [
+    "create_app",
+    "ChatModel",
+    "Evaluator",
+    "export_model",
+    "run_exp",
+    "create_ui",
+    "create_web_demo"
+]
src/llmtuner/api/__init__.py
@@ -1 +1,4 @@
-from llmtuner.api.app import create_app
+from .app import create_app
+
+
+__all__ = ["create_app"]
src/llmtuner/api/app.py
@@ -5,7 +5,7 @@ from typing import List, Tuple
 from pydantic import BaseModel
 from contextlib import asynccontextmanager
 
-from llmtuner.api.protocol import (
+from .protocol import (
     Role,
     Finish,
     ModelCard,
@@ -21,9 +21,9 @@ from llmtuner.api.protocol import (
     ScoreEvaluationRequest,
     ScoreEvaluationResponse
 )
-from llmtuner.chat import ChatModel
-from llmtuner.extras.misc import torch_gc
-from llmtuner.extras.packages import (
+from ..chat import ChatModel
+from ..extras.misc import torch_gc
+from ..extras.packages import (
     is_fastapi_availble, is_starlette_available, is_uvicorn_available
 )
 
src/llmtuner/api/protocol.py
@@ -1,15 +1,17 @@
 import time
-from enum import Enum
+from enum import Enum, unique
 from pydantic import BaseModel, Field
 from typing import List, Optional
 
 
+@unique
 class Role(str, Enum):
     USER = "user"
     ASSISTANT = "assistant"
     SYSTEM = "system"
 
 
+@unique
 class Finish(str, Enum):
     STOP = "stop"
     LENGTH = "length"
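Note: @unique (newly imported from enum) makes duplicate member values an import-time error, so two roles or finish reasons can never silently alias the same wire value. A small illustration, not part of the commit:

    from enum import Enum, unique

    @unique
    class Finish(str, Enum):
        STOP = "stop"
        LENGTH = "length"

    # A duplicated value now fails fast instead of aliasing:
    # @unique
    # class Bad(str, Enum):
    #     A = "x"
    #     B = "x"   # ValueError: duplicate values found in <enum 'Bad'>: B -> A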
src/llmtuner/chat/__init__.py
@@ -1 +1,4 @@
-from llmtuner.chat.chat_model import ChatModel
+from .chat_model import ChatModel
+
+
+__all__ = ["ChatModel"]
src/llmtuner/chat/chat_model.py
@@ -1,13 +1,13 @@
 import torch
-import tiktoken
 from dataclasses import dataclass
 from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
 from threading import Thread
 from transformers import GenerationConfig, TextIteratorStreamer
 
-from llmtuner.data.template import get_template_and_fix_tokenizer
-from llmtuner.extras.misc import get_logits_processor
-from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer
+from ..data import get_template_and_fix_tokenizer
+from ..extras.misc import get_logits_processor
+from ..model import dispatch_model, load_model_and_tokenizer
+from ..hparams import get_infer_args
 
 
 @dataclass
@@ -139,11 +139,6 @@ class ChatModel:
         batch_input: List[str],
         **input_kwargs
     ) -> List[float]:
-        if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
-            kwargs = dict(allowed_special="all")
-        else:
-            kwargs = dict(add_special_tokens=True)
-
         max_length = input_kwargs.pop("max_length", None)
         device = getattr(self.model.pretrained_model, "device", "cuda")
 
@@ -153,7 +148,7 @@ class ChatModel:
             truncation=True,
             max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
             return_tensors="pt",
-            **kwargs
+            add_special_tokens=True
         ).to(device)
 
         input_ids: torch.Tensor = inputs["input_ids"]
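Note: this removes the last Qwen-specific tiktoken branch from the scoring path, which is also why tiktoken disappears from requirements.txt above. The batch is now tokenized one way regardless of tokenizer backend; a sketch of the resulting call, assembled from the diff lines above (padding and other surrounding arguments are not shown in the hunk):

    # before: **kwargs was either dict(allowed_special="all")  (tiktoken/Qwen)
    #         or dict(add_special_tokens=True)                  (everything else)
    # after: one uniform path
    inputs = self.tokenizer(
        batch_input,
        truncation=True,
        max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
        return_tensors="pt",
        add_special_tokens=True
    ).to(device)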
src/llmtuner/data/__init__.py
@@ -1,4 +1,6 @@
-from llmtuner.data.loader import get_dataset
-from llmtuner.data.preprocess import preprocess_dataset
-from llmtuner.data.template import get_template_and_fix_tokenizer
-from llmtuner.data.utils import split_dataset
+from .loader import get_dataset
+from .template import get_template_and_fix_tokenizer, templates
+from .utils import split_dataset
+
+
+__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset"]
src/llmtuner/data/aligner.py (new file, 106 lines)
@@ -0,0 +1,106 @@
+from functools import partial
+from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+from .utils import Role
+
+
+if TYPE_CHECKING:
+    from datasets import Dataset, IterableDataset
+
+    from ..hparams import DataArguments
+    from .parser import DatasetAttr
+
+
+def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
+    outputs = {"prompt": [], "response": [], "system": [], "tool": []}
+    for i in range(len(examples[dataset_attr.prompt])):
+        prompt = []
+        if dataset_attr.history:
+            for old_prompt, old_response in examples[dataset_attr.history][i]:
+                prompt.append({"role": Role.USER, "content": old_prompt})
+                prompt.append({"role": Role.ASSISTANT, "content": old_response})
+
+        instruction = examples[dataset_attr.prompt][i]
+        if dataset_attr.query and examples[dataset_attr.query][i]:
+            instruction += "\n" + examples[dataset_attr.query][i]
+        prompt.append({"role": Role.USER, "content": instruction})
+
+        if isinstance(examples[dataset_attr.response][i], list):
+            response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
+        else:
+            response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
+
+        outputs["prompt"].append(prompt)
+        outputs["response"].append(response)
+        outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
+        outputs["tool"].append("")
+
+    return outputs
+
+
+def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
+    outputs = {"prompt": [], "response": [], "system": [], "tool": []}
+    tag_mapping = {
+        dataset_attr.user_tag: Role.USER,
+        dataset_attr.assistant_tag: Role.ASSISTANT,
+        dataset_attr.observation_tag: Role.OBSERVATION,
+        dataset_attr.function_tag: Role.FUNCTION
+    }
+    for i, messages in enumerate(examples[dataset_attr.messages]):
+        messages = messages[:len(messages) // 2 * 2] # should be multiples of 2
+        if len(messages) == 0:
+            continue
+
+        prompt = []
+        response = []
+        for turn_idx, message in enumerate(messages):
+            if turn_idx % 2 == 0:
+                accept_tags = [dataset_attr.user_tag, dataset_attr.observation_tag]
+            else:
+                accept_tags = [dataset_attr.assistant_tag, dataset_attr.function_tag]
+
+            if message[dataset_attr.role_tag] not in accept_tags:
+                raise ValueError("Invalid role tag.")
+
+            prompt.append({"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]})
+
+        last_message = prompt.pop(-1)
+        response.append(last_message)
+        outputs["prompt"].append(prompt)
+        outputs["response"].append(response)
+        outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
+        outputs["tool"].append(examples[dataset_attr.tool][i] if dataset_attr.tool else "")
+
+    return outputs
+
+
+def align_dataset(
+    dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
+) -> Union["Dataset", "IterableDataset"]:
+    r"""
+    Aligned dataset:
+        prompt: [{"role": "user", "content": "..."}]
+        response: [{"role": "assistant", "content": "..."}]
+        system: "..."
+        tool: "..."
+    """
+    if dataset_attr.formatting == "alpaca":
+        convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
+    else:
+        convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
+
+    column_names = list(next(iter(dataset)).keys())
+    kwargs = {}
+    if not data_args.streaming:
+        kwargs = dict(
+            num_proc=data_args.preprocessing_num_workers,
+            load_from_cache_file=(not data_args.overwrite_cache),
+            desc="Converting format of dataset"
+        )
+
+    return dataset.map(
+        convert_func,
+        batched=True,
+        remove_columns=column_names,
+        **kwargs
+    )
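Note: align_dataset maps every source onto the same four columns (prompt, response, system, tool). For sharegpt data, even-numbered turns must carry the user/observation tags and odd-numbered turns the assistant/function tags, and the final message is split off as the response. A sketch of what convert_sharegpt yields for a hand-built one-round batch (the DatasetAttr values are assumed to mirror the glaive_func_call entry above):

    # sketch: one batch row in glaive's on-disk layout
    examples = {
        "conversations": [[
            {"from": "human", "value": "Generate a password."},
            {"from": "function_call", "value": '{"name": "generate_password", "arguments": {"length": 12}}'},
        ]],
        "tools": ['[{"name": "generate_password"}]'],
    }
    # convert_sharegpt(examples, dataset_attr) then produces, per row:
    #   prompt   -> [{"role": Role.USER, "content": "Generate a password."}]
    #   response -> [{"role": Role.FUNCTION, "content": '{"name": "generate_password", ...}'}]
    #   system   -> ""   (no system column configured)
    #   tool     -> '[{"name": "generate_password"}]'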
src/llmtuner/data/formatter.py (new file, 99 lines)
@@ -0,0 +1,99 @@
+import json
+from dataclasses import dataclass
+from typing import Any, Dict, List, Literal, Union
+
+
+JSON_FORMAT_PROMPT = (
+    """, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
+)
+
+
+TOOL_SYSTEM_PROMPT = (
+    "You have access to the following tools:\n{tool_text}"
+    "Use the following format to answer the question:\n"
+    "```\n"
+    "Action: the action to take, should be one of [{tool_names}] if using a tool.\n"
+    "Action Input: the input to the action{format_prompt}.\n"
+    "```"
+)
+
+
+@dataclass
+class StringFormatter:
+    container: List[Union[str, Dict[str, str]]]
+
+    def __call__(self, **kwargs) -> List[Union[str, Dict[str, str]]]:
+        elements = []
+        for elem in self.container:
+            if isinstance(elem, str):
+                for name, value in kwargs.items():
+                    elem = elem.replace("{{" + name + "}}", value)
+                elements.append(elem)
+            elif isinstance(elem, (dict, set)):
+                elements.append(elem)
+            else:
+                raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
+
+        return elements
+
+
+@dataclass
+class FunctionFormatter:
+    container: List[Union[str, Dict[str, str]]]
+
+    def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
+        try:
+            function = json.loads(content)
+            name = json.dumps(function["name"], ensure_ascii=False)
+            arguments = json.dumps(function["arguments"], ensure_ascii=False)
+        except json.JSONDecodeError:
+            name, arguments = "", ""
+
+        elements = []
+        for elem in self.container:
+            if isinstance(elem, str):
+                elem = elem.replace("{{name}}", name)
+                elem = elem.replace("{{arguments}}", arguments)
+                elements.append(elem)
+            elif isinstance(elem, (dict, set)):
+                elements.append(elem)
+            else:
+                raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
+
+        return elements
+
+
+@dataclass
+class ToolFormatter:
+    type: Literal["default"]
+
+    def _default(self, tools: List[Dict[str, Any]]) -> str:
+        tool_text = ""
+        tool_names = []
+        for tool in tools:
+            param_text = ""
+            for name, param in tool["parameters"]["properties"].items():
+                required = ", required" if name in tool["parameters"].get("required", []) else ""
+                enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
+                param_text += "  - {name} ({type}{required}): {desc}{enum}\n".format(
+                    name=name, type=param.get("type", ""), required=required, desc=param.get("description", ""), enum=enum
+                )
+
+            tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
+                name=tool["name"], desc=tool.get("description", ""), args=param_text
+            )
+            tool_names.append(tool["name"])
+
+        return TOOL_SYSTEM_PROMPT.format(
+            tool_text=tool_text,
+            tool_names=", ".join(tool_names),
+            format_prompt=JSON_FORMAT_PROMPT
+        )
+
+    def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
+        try:
+            tools = json.loads(content)
+            if self.type == "default":
+                return [self._default(tools)]
+        except json.JSONDecodeError:
+            return [""]
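Note: ToolFormatter turns the JSON tool list into the TOOL_SYSTEM_PROMPT text that gets prepended as the system prompt. For a single-tool list containing the generate_password spec above, the rendered prompt is roughly the following (whitespace approximate; the inner backticks are literal characters of the prompt string):

    You have access to the following tools:
    > Tool Name: generate_password
    Tool Description: Generate a random password
    Tool Args:
      - length (integer, required): The length of the password
      - include_symbols (boolean): Whether to include symbols in the password

    Use the following format to answer the question:
    ```
    Action: the action to take, should be one of [generate_password] if using a tool.
    Action Input: the input to the action, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```).
    ```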
src/llmtuner/data/loader.py
@@ -1,160 +1,114 @@
-import os
-from typing import TYPE_CHECKING, Any, Dict, List, Union
-
-from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
-
-from llmtuner.data.utils import checksum
-from llmtuner.extras.constants import FILEEXT2TYPE
-from llmtuner.extras.logging import get_logger
-
-
-if TYPE_CHECKING:
-    from datasets import Dataset, IterableDataset
-    from llmtuner.hparams import ModelArguments, DataArguments
-
-
-logger = get_logger(__name__)
-
-
-def get_dataset(
-    model_args: "ModelArguments",
-    data_args: "DataArguments"
-) -> Union["Dataset", "IterableDataset"]:
-    max_samples = data_args.max_samples
-    all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
-
-    if data_args.cache_path is not None:
-        if os.path.exists(data_args.cache_path):
-            logger.warning("Loading dataset from disk will ignore other data arguments.")
-            dataset = load_from_disk(data_args.cache_path)
-            if data_args.streaming:
-                dataset = dataset.to_iterable_dataset()
-            return dataset
-        elif data_args.streaming:
-            raise ValueError("Turn off dataset streaming to save cache files.")
-
-    for dataset_attr in data_args.dataset_list:
-        logger.info("Loading dataset {}...".format(dataset_attr))
-
-        data_path, data_name, data_dir, data_files = None, None, None, None
-        if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
-            data_path = dataset_attr.dataset_name
-            data_name = dataset_attr.subset
-            data_dir = dataset_attr.folder
-        elif dataset_attr.load_from == "script":
-            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
-            data_name = dataset_attr.subset
-        elif dataset_attr.load_from == "file":
-            data_files = []
-            local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
-            if os.path.isdir(local_path): # is directory
-                for file_name in os.listdir(local_path):
-                    data_files.append(os.path.join(local_path, file_name))
-                    if data_path is None:
-                        data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
-                    else:
-                        assert data_path == FILEEXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
-            elif os.path.isfile(local_path): # is file
-                data_files.append(local_path)
-                data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
-            else:
-                raise ValueError("File not found.")
-
-            assert data_path, "File extension must be txt, csv, json or jsonl."
-            checksum(data_files, dataset_attr.dataset_sha1)
-        else:
-            raise NotImplementedError
-
-        if dataset_attr.load_from == "ms_hub":
-            try:
-                from modelscope import MsDataset
-                from modelscope.utils.config_ds import MS_DATASETS_CACHE
-
-                cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
-                dataset = MsDataset.load(
-                    dataset_name=data_path,
-                    subset_name=data_name,
-                    data_dir=data_dir,
-                    data_files=data_files,
-                    split=data_args.split,
-                    cache_dir=cache_dir,
-                    token=model_args.ms_hub_token,
-                    use_streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
-                ).to_hf_dataset()
-            except ImportError:
-                raise ImportError("Please install modelscope via `pip install modelscope -U`")
-        else:
-            dataset = load_dataset(
-                path=data_path,
-                name=data_name,
-                data_dir=data_dir,
-                data_files=data_files,
-                split=data_args.split,
-                cache_dir=model_args.cache_dir,
-                token=model_args.hf_hub_token,
-                streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
-            )
-
-        if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
-            dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
-
-        if max_samples is not None: # truncate dataset
-            dataset = dataset.select(range(min(len(dataset), max_samples)))
-
-        def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
-            # convert dataset from sharegpt format to alpaca format
-            outputs = {"prompt": [], "query": [], "response": [], "history": [], "system": []}
-            for i, msg_list in enumerate(examples[dataset_attr.messages]):
-                msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
-                if len(msg_list) == 0:
-                    continue
-
-                msg_pairs = []
-                user_role, assistant_role = None, None
-                for idx in range(0, len(msg_list), 2):
-                    if user_role is None and assistant_role is None:
-                        user_role = msg_list[idx][dataset_attr.role]
-                        assistant_role = msg_list[idx + 1][dataset_attr.role]
-                    else:
-                        if (
-                            msg_list[idx][dataset_attr.role] != user_role
-                            or msg_list[idx+1][dataset_attr.role] != assistant_role
-                        ):
-                            raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
-                    msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))
-
-                if len(msg_pairs) != 0:
-                    outputs["prompt"].append(msg_pairs[-1][0])
-                    outputs["query"].append("")
-                    outputs["response"].append(msg_pairs[-1][1])
-                    outputs["history"].append(msg_pairs[:-1] if len(msg_pairs) > 1 else None)
-                    outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
-
-            return outputs
-
-        if dataset_attr.formatting == "sharegpt": # convert format
-            column_names = list(next(iter(dataset)).keys())
-            kwargs = {}
-            if not data_args.streaming:
-                kwargs = dict(
-                    num_proc=data_args.preprocessing_num_workers,
-                    load_from_cache_file=(not data_args.overwrite_cache),
-                    desc="Converting format of dataset"
-                )
-
-            dataset = dataset.map(
-                convert_format,
-                batched=True,
-                remove_columns=column_names,
-                **kwargs
-            )
-        else:
-            for column_name in ["prompt", "query", "response", "history", "system"]: # align dataset
-                if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
-                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
-
-        all_datasets.append(dataset)
-
-    if len(data_args.dataset_list) == 1:
-        return all_datasets[0]
-    elif data_args.mix_strategy == "concat":
-        if data_args.streaming:
+import os
+from typing import TYPE_CHECKING, List, Literal, Union
+
+from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
+
+from ..extras.constants import FILEEXT2TYPE
+from ..extras.logging import get_logger
+from .utils import checksum
+from .parser import get_dataset_list
+from .aligner import align_dataset
+from .template import get_template_and_fix_tokenizer
+from .preprocess import get_preprocess_and_print_func
+
+
+if TYPE_CHECKING:
+    from datasets import Dataset, IterableDataset
+    from transformers import Seq2SeqTrainingArguments
+    from transformers.tokenization_utils import PreTrainedTokenizer
+
+    from .parser import DatasetAttr
+    from ..hparams import ModelArguments, DataArguments
+
+
+logger = get_logger(__name__)
+
+
+def load_single_dataset(
+    dataset_attr: "DatasetAttr",
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+):
+    data_path, data_name, data_dir, data_files = None, None, None, None
+    if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
+        data_path = dataset_attr.dataset_name
+        data_name = dataset_attr.subset
+        data_dir = dataset_attr.folder
+
+    elif dataset_attr.load_from == "script":
+        data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+        data_name = dataset_attr.subset
+        data_dir = dataset_attr.folder
+
+    elif dataset_attr.load_from == "file":
+        data_files = []
+        local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+        if os.path.isdir(local_path): # is directory
+            for file_name in os.listdir(local_path):
+                data_files.append(os.path.join(local_path, file_name))
+                if data_path is None:
+                    data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
+                elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
+                    raise ValueError("File types should be identical.")
+        elif os.path.isfile(local_path): # is file
+            data_files.append(local_path)
+            data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
+        else:
+            raise ValueError("File not found.")
+
+        if data_path is None:
+            raise ValueError("File extension must be txt, csv, json or jsonl.")
+
+        checksum(data_files, dataset_attr.dataset_sha1)
+    else:
+        raise NotImplementedError
+
+    if dataset_attr.load_from == "ms_hub":
+        try:
+            from modelscope import MsDataset
+            from modelscope.utils.config_ds import MS_DATASETS_CACHE
+
+            cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
+            dataset = MsDataset.load(
+                dataset_name=data_path,
+                subset_name=data_name,
+                data_dir=data_dir,
+                data_files=data_files,
+                split=data_args.split,
+                cache_dir=cache_dir,
+                token=model_args.ms_hub_token,
+                use_streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
+            ).to_hf_dataset()
+        except ImportError:
+            raise ImportError("Please install modelscope via `pip install modelscope -U`")
+    else:
+        dataset = load_dataset(
+            path=data_path,
+            name=data_name,
+            data_dir=data_dir,
+            data_files=data_files,
+            split=data_args.split,
+            cache_dir=model_args.cache_dir,
+            token=model_args.hf_hub_token,
+            streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
+        )
+
+    if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
+        dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
+
+    if data_args.max_samples is not None: # truncate dataset
+        num_samples = min(data_args.max_samples, len(dataset))
+        dataset = dataset.select(range(num_samples))
+
+    return align_dataset(dataset, dataset_attr, data_args)
+
+
+def merge_dataset(
+    all_datasets: List[Union["Dataset", "IterableDataset"]],
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments"
+) -> Union["Dataset", "IterableDataset"]:
+    if len(all_datasets) == 1:
+        return all_datasets[0]
+    elif data_args.mix_strategy == "concat":
+        if data_args.streaming:
@@ -166,8 +120,72 @@ def get_dataset(
         return interleave_datasets(
             datasets=all_datasets,
             probabilities=data_args.interleave_probs,
-            seed=data_args.seed,
+            seed=training_args.seed,
             stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
         )
     else:
         raise ValueError("Unknown mixing strategy.")
+
+
+def get_dataset(
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    tokenizer: "PreTrainedTokenizer",
+    training_args: "Seq2SeqTrainingArguments",
+    stage: Literal["pt", "sft", "rm", "ppo"],
+    # split: Optional[str] = "train", # TODO: add split
+) -> Union["Dataset", "IterableDataset"]:
+    template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
+    if data_args.train_on_prompt and template.efficient_eos:
+        raise ValueError("Current template does not support `train_on_prompt`.")
+
+    # Load from cache
+    if data_args.cache_path is not None:
+        if os.path.exists(data_args.cache_path):
+            logger.warning("Loading dataset from disk will ignore other data arguments.")
+            dataset = load_from_disk(data_args.cache_path)
+            if data_args.streaming:
+                dataset = dataset.to_iterable_dataset()
+            return dataset
+
+        if data_args.streaming:
+            raise ValueError("Turn off dataset streaming to save cache files.")
+
+    with training_args.main_process_first(desc="load dataset"):
+        all_datasets = []
+        for dataset_attr in get_dataset_list(data_args): # TODO: add split
+            all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
+        dataset = merge_dataset(all_datasets, data_args, training_args)
+
+    with training_args.main_process_first(desc="pre-process dataset"):
+        preprocess_func, print_function = get_preprocess_and_print_func(
+            tokenizer, template, data_args, training_args, stage
+        )
+        column_names = list(next(iter(dataset)).keys())
+        kwargs = {}
+        if not data_args.streaming:
+            kwargs = dict(
+                num_proc=data_args.preprocessing_num_workers,
+                load_from_cache_file=(not data_args.overwrite_cache),
+                desc="Running tokenizer on dataset"
+            )
+
+        dataset = dataset.map(
+            preprocess_func,
+            batched=True,
+            remove_columns=column_names,
+            **kwargs
+        )
+
+        if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
+            if training_args.should_save:
+                dataset.save_to_disk(data_args.cache_path)
+                logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
+
+        if training_args.should_log:
+            try:
+                print_function(next(iter(dataset)))
+            except StopIteration:
+                raise RuntimeError("Empty dataset!")
+
+        return dataset
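Note: loading is now split into three stages: load_single_dataset fetches one source and normalizes it through align_dataset, merge_dataset concatenates or interleaves the sources, and get_dataset handles the template check, caching, and tokenization via get_preprocess_and_print_func. A sketch of the call flow, assuming the argument objects are built by the project's hparams parsers (not shown here):

    # sketch only: model_args/data_args/training_args come from llmtuner's hparams parsing
    dataset = get_dataset(
        model_args,      # hub tokens, cache_dir, ...
        data_args,       # dataset names, template, cutoff_len, streaming, ...
        tokenizer,       # a transformers PreTrainedTokenizer
        training_args,   # Seq2SeqTrainingArguments (seed, main_process_first, should_save)
        stage="sft",     # one of "pt", "sft", "rm", "ppo"
    )
    # -> a datasets Dataset/IterableDataset with model-ready input_ids/attention_mask/labels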
src/llmtuner/data/parser.py (new file, 101 lines)
@@ -0,0 +1,101 @@
+import os
+import json
+from typing import TYPE_CHECKING, List, Literal, Optional
+from dataclasses import dataclass
+
+from ..extras.constants import DATA_CONFIG
+from ..extras.misc import use_modelscope
+
+if TYPE_CHECKING:
+    from ..hparams import DataArguments
+
+
+@dataclass
+class DatasetAttr:
+
+    load_from: Literal["hf_hub", "ms_hub", "script", "file"]
+    dataset_name: Optional[str] = None
+    dataset_sha1: Optional[str] = None
+    subset: Optional[str] = None
+    folder: Optional[str] = None
+    ranking: Optional[bool] = False
+    formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
+
+    system: Optional[str] = None
+
+    prompt: Optional[str] = "instruction"
+    query: Optional[str] = "input"
+    response: Optional[str] = "output"
+    history: Optional[str] = None
+
+    messages: Optional[str] = "conversations"
+    tool: Optional[str] = None
+
+    role_tag: Optional[str] = "from"
+    content_tag: Optional[str] = "value"
+    user_tag: Optional[str] = "human"
+    assistant_tag: Optional[str] = "gpt"
+    observation_tag: Optional[str] = "observation"
+    function_tag: Optional[str] = "function_call"
+
+    def __repr__(self) -> str:
+        return self.dataset_name
+
+
+def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
+    dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
+    try:
+        with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
+            dataset_info = json.load(f)
+    except Exception as err:
+        if data_args.dataset is not None:
+            raise ValueError("Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)))
+        dataset_info = None
+
+    if data_args.interleave_probs is not None:
+        data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
+
+    dataset_list: List[DatasetAttr] = []
+    for name in dataset_names:
+        if name not in dataset_info:
+            raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
+
+        has_hf_url = "hf_hub_url" in dataset_info[name]
+        has_ms_url = "ms_hub_url" in dataset_info[name]
+
+        if has_hf_url or has_ms_url:
+            if (use_modelscope() and has_ms_url) or (not has_hf_url):
+                dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
+            else:
+                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
+        elif "script_url" in dataset_info[name]:
+            dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
+        else:
+            dataset_attr = DatasetAttr(
+                "file",
+                dataset_name=dataset_info[name]["file_name"],
+                dataset_sha1=dataset_info[name].get("file_sha1", None)
+            )
+
+        dataset_attr.subset = dataset_info[name].get("subset", None)
+        dataset_attr.folder = dataset_info[name].get("folder", None)
+        dataset_attr.ranking = dataset_info[name].get("ranking", False)
+        dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
+
+        if "columns" in dataset_info[name]:
+            if dataset_attr.formatting == "alpaca":
+                column_names = ["prompt", "query", "response", "history"]
+            else:
+                column_names = ["messages", "tool"]
+
+            column_names += ["system"]
+            for column_name in column_names:
+                setattr(dataset_attr, column_name, dataset_info[name]["columns"].get(column_name, None))
+
+        if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
+            for tag in ["role_tag", "content_tag", "user_tag", "assistant_tag", "observation_tag", "function_tag"]:
+                setattr(dataset_attr, tag, dataset_info[name]["tags"].get(tag, None))
+
+        dataset_list.append(dataset_attr)
+
+    return dataset_list
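Note: given the glaive_func_call registry entry above, get_dataset_list resolves to a file-backed DatasetAttr; anything the entry omits keeps the dataclass defaults, which happen to match glaive's on-disk tags. Roughly equivalent, as a sketch:

    # sketch: the DatasetAttr that the glaive_func_call entry resolves to
    dataset_attr = DatasetAttr("file", dataset_name="glaive_func_call.json")
    dataset_attr.formatting = "sharegpt"
    dataset_attr.messages = "conversations"
    dataset_attr.tool = "tools"
    # defaults still apply, since the entry has no "tags" block:
    #   role_tag="from", content_tag="value",
    #   user_tag="human", assistant_tag="gpt",
    #   observation_tag="observation", function_tag="function_call"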
src/llmtuner/data/preprocess.py
@@ -1,272 +1,241 @@
-import os
-import tiktoken
-from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
-
-from llmtuner.data.template import get_template_and_fix_tokenizer
-from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.logging import get_logger
-
-if TYPE_CHECKING:
-    from datasets import Dataset, IterableDataset
-    from transformers import Seq2SeqTrainingArguments
-    from transformers.tokenization_utils import PreTrainedTokenizer
-    from llmtuner.hparams import DataArguments
-
-
-logger = get_logger(__name__)
-
-
-def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
-    for i in range(len(examples["prompt"])):
-        query, response = examples["prompt"][i], examples["response"][i]
-        query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
-        history = examples["history"][i] if "history" in examples else None
-        system = examples["system"][i] if "system" in examples else None
-        yield query, response, history, system
-
-
-def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
-    max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
-    max_target_len = max(max_target_len, data_args.reserved_label_len)
-    max_source_len = data_args.cutoff_len - max_target_len
-    return max_source_len, max_target_len
-
-
-def preprocess_dataset(
-    dataset: Union["Dataset", "IterableDataset"],
-    tokenizer: "PreTrainedTokenizer",
-    data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments",
-    stage: Literal["pt", "sft", "rm", "ppo"]
-) -> Union["Dataset", "IterableDataset"]:
-    template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
-
-    if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
-        return dataset # already preprocessed
-
-    if data_args.train_on_prompt and template.efficient_eos:
-        raise ValueError("Current template does not support `train_on_prompt`.")
-
-    def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
-        # build grouped texts with format `X1 X2 X3 ...`
-        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
-            kwargs = dict(allowed_special="all")
-        else:
-            kwargs = dict(add_special_tokens=True)
-
-        if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer
-            add_eos_token_flag = getattr(tokenizer, "add_eos_token")
-            setattr(tokenizer, "add_eos_token", True)
-
-        tokenized_examples = tokenizer(examples["prompt"], **kwargs)
-        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
-        total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
-        block_size = data_args.cutoff_len
-        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
-        total_length = (total_length // block_size) * block_size
-        # split by chunks of cutoff_len
-        result = {
-            k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
-            for k, t in concatenated_examples.items()
-        }
-        # make sure the saved tokenizer is the same as the original one
-        if hasattr(tokenizer, "add_eos_token"):
-            setattr(tokenizer, "add_eos_token", add_eos_token_flag)
-        return result
-
-    def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
-        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
-        # for multiturn examples, we only mask the prompt part in each prompt-response pair.
-        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-
-        for query, response, history, system in construct_example(examples):
-            if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
-                continue
-
-            input_ids, labels = [], []
-            for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
-                tokenizer, query, response, history, system
-            )):
-                source_len, target_len = len(source_ids), len(target_ids)
-                max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
-                if source_len > max_source_len:
-                    source_ids = source_ids[:max_source_len]
-                if target_len > max_target_len:
-                    target_ids = target_ids[:max_target_len]
-
-                if data_args.train_on_prompt:
-                    source_mask = source_ids
-                elif turn_idx != 0 and template.efficient_eos:
-                    source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
-                else:
-                    source_mask = [IGNORE_INDEX] * len(source_ids)
-
-                input_ids += source_ids + target_ids
-                labels += source_mask + target_ids
-
-            if template.efficient_eos:
-                input_ids += [tokenizer.eos_token_id]
-                labels += [tokenizer.eos_token_id]
-
-            if len(input_ids) > data_args.cutoff_len:
-                input_ids = input_ids[:data_args.cutoff_len]
-                labels = labels[:data_args.cutoff_len]
-
-            model_inputs["input_ids"].append(input_ids)
-            model_inputs["attention_mask"].append([1] * len(input_ids))
-            model_inputs["labels"].append(labels)
-
-        return model_inputs
-
-    def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
-        # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
-        # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
-        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-        input_ids, labels = [], []
-        for query, response, history, system in construct_example(examples):
-            if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
-                continue
-
-            for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
-                tokenizer, query, response, history, system
-            )):
-                if data_args.train_on_prompt:
-                    source_mask = source_ids
-                elif turn_idx != 0 and template.efficient_eos:
-                    source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
-                else:
-                    source_mask = [IGNORE_INDEX] * len(source_ids)
-                input_ids += source_ids + target_ids
-                labels += source_mask + target_ids
-
-        if template.efficient_eos:
-            input_ids += [tokenizer.eos_token_id]
-            labels += [tokenizer.eos_token_id]
-
-        total_length = len(input_ids)
-        block_size = data_args.cutoff_len
-        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
-        total_length = (total_length // block_size) * block_size
-        # split by chunks of cutoff_len
-        for i in range(0, total_length, block_size):
-            model_inputs["input_ids"].append(input_ids[i: i + block_size])
-            model_inputs["attention_mask"].append([1] * block_size)
-            model_inputs["labels"].append(labels[i: i + block_size])
-
-        return model_inputs
-
-    def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
-        # build inputs with format `<bos> X` and labels with format `Y <eos>`
-        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-
-        for query, response, history, system in construct_example(examples):
-            if not (isinstance(query, str) and query != ""):
-                continue
-
-            input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
-
-            if template.efficient_eos:
-                labels += [tokenizer.eos_token_id]
-
-            if len(input_ids) > data_args.cutoff_len:
-                input_ids = input_ids[:data_args.cutoff_len]
-            if len(labels) > data_args.cutoff_len:
-                labels = labels[:data_args.cutoff_len]
-
-            model_inputs["input_ids"].append(input_ids)
-            model_inputs["attention_mask"].append([1] * len(input_ids))
-            model_inputs["labels"].append(labels)
-
-        return model_inputs
-
-    def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
-        # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
-        model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
-        for query, response, history, system in construct_example(examples):
-            if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
-                continue
-
-            prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
-            _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
-
-            if template.efficient_eos:
-                chosen_ids += [tokenizer.eos_token_id]
-                rejected_ids += [tokenizer.eos_token_id]
-
-            source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids))
-            max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
-            if source_len > max_source_len:
-                prompt_ids = prompt_ids[:max_source_len]
-            if target_len > max_target_len:
-                chosen_ids = chosen_ids[:max_target_len]
-                rejected_ids = rejected_ids[:max_target_len]
-
-            model_inputs["prompt_ids"].append(prompt_ids)
-            model_inputs["chosen_ids"].append(chosen_ids)
-            model_inputs["rejected_ids"].append(rejected_ids)
-
-        return model_inputs
-
-    def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None:
-        print("input_ids:\n{}".format(example["input_ids"]))
-        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
-        print("label_ids:\n{}".format(example["labels"]))
-        print("labels:\n{}".format(
-            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
-        ))
-
-    def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
-        print("prompt_ids:\n{}".format(example["prompt_ids"]))
-        print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
-        print("chosen_ids:\n{}".format(example["chosen_ids"]))
-        print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
-        print("rejected_ids:\n{}".format(example["rejected_ids"]))
-        print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
-
-    def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None:
-        print("input_ids:\n{}".format(example["input_ids"]))
-        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
-
-    if stage == "pt":
-        preprocess_func = preprocess_pretrain_dataset
-        print_function = print_unsupervised_dataset_example
-    elif stage == "sft" and not training_args.predict_with_generate:
-        preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
-        print_function = print_supervised_dataset_example
-    elif stage == "rm":
-        preprocess_func = preprocess_pairwise_dataset
-        print_function = print_pairwise_dataset_example
-    else:
-        preprocess_func = preprocess_unsupervised_dataset
-        print_function = print_unsupervised_dataset_example
-
-    with training_args.main_process_first(desc="dataset map pre-processing"):
-        column_names = list(next(iter(dataset)).keys())
-        kwargs = {}
-        if not data_args.streaming:
-            kwargs = dict(
-                num_proc=data_args.preprocessing_num_workers,
-                load_from_cache_file=(not data_args.overwrite_cache),
-                desc="Running tokenizer on dataset"
-            )
-
-        dataset = dataset.map(
-            preprocess_func,
-            batched=True,
-            remove_columns=column_names,
-            **kwargs
-        )
-
-        if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
-            if training_args.should_save:
-                dataset.save_to_disk(data_args.cache_path)
-                logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
-
-        if training_args.should_log:
-            try:
-                print_function(next(iter(dataset)))
-            except StopIteration:
-                raise RuntimeError("Empty dataset!")
-
-    return dataset
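Note: the replacement module, added below, turns the closures above into top-level functions parameterized by (tokenizer, template, data_args) so loader.py can dispatch them; truncation moves into template.encode_oneturn/encode_multiturn via an explicit cutoff argument, and the packed variant encodes with an effectively unlimited cutoff (1_000_000) and then re-chunks the concatenated stream into cutoff_len blocks. A standalone sketch of that chunking arithmetic (names illustrative):

    # sketch of the packed re-chunking: drop the tail remainder, then slice fixed blocks
    def chunk(ids, block_size):
        total_length = (len(ids) // block_size) * block_size
        return [ids[i: i + block_size] for i in range(0, total_length, block_size)]

    assert chunk(list(range(10)), 4) == [[0, 1, 2, 3], [4, 5, 6, 7]]  # remainder 8, 9 dropped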
||||||
|
|
||||||
-    def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
-        # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
-        model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
-        for query, response, history, system in construct_example(examples):
-            if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
-                continue
-
-            prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
-            _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
-
-            if template.efficient_eos:
-                chosen_ids += [tokenizer.eos_token_id]
-                rejected_ids += [tokenizer.eos_token_id]
-
-            source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids))
-            max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
-            if source_len > max_source_len:
-                prompt_ids = prompt_ids[:max_source_len]
-            if target_len > max_target_len:
-                chosen_ids = chosen_ids[:max_target_len]
-                rejected_ids = rejected_ids[:max_target_len]
-
-            model_inputs["prompt_ids"].append(prompt_ids)
-            model_inputs["chosen_ids"].append(chosen_ids)
-            model_inputs["rejected_ids"].append(rejected_ids)
-
-        return model_inputs
-
-    def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None:
-        print("input_ids:\n{}".format(example["input_ids"]))
-        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
-        print("label_ids:\n{}".format(example["labels"]))
-        print("labels:\n{}".format(
-            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
-        ))
-
-    def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
-        print("prompt_ids:\n{}".format(example["prompt_ids"]))
-        print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
-        print("chosen_ids:\n{}".format(example["chosen_ids"]))
-        print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
-        print("rejected_ids:\n{}".format(example["rejected_ids"]))
-        print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
-
-    def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None:
-        print("input_ids:\n{}".format(example["input_ids"]))
-        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
-
-    if stage == "pt":
-        preprocess_func = preprocess_pretrain_dataset
-        print_function = print_unsupervised_dataset_example
-    elif stage == "sft" and not training_args.predict_with_generate:
-        preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
-        print_function = print_supervised_dataset_example
-    elif stage == "rm":
-        preprocess_func = preprocess_pairwise_dataset
-        print_function = print_pairwise_dataset_example
-    else:
-        preprocess_func = preprocess_unsupervised_dataset
-        print_function = print_unsupervised_dataset_example
-
-    with training_args.main_process_first(desc="dataset map pre-processing"):
-        column_names = list(next(iter(dataset)).keys())
-        kwargs = {}
-        if not data_args.streaming:
-            kwargs = dict(
-                num_proc=data_args.preprocessing_num_workers,
-                load_from_cache_file=(not data_args.overwrite_cache),
-                desc="Running tokenizer on dataset"
-            )
-
-        dataset = dataset.map(
-            preprocess_func,
-            batched=True,
-            remove_columns=column_names,
-            **kwargs
-        )
-
-        if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
-            if training_args.should_save:
-                dataset.save_to_disk(data_args.cache_path)
-                logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
-
-        if training_args.should_log:
-            try:
-                print_function(next(iter(dataset)))
-            except StopIteration:
-                raise RuntimeError("Empty dataset!")
-
-        return dataset
+def preprocess_unsupervised_dataset(
+    examples: Dict[str, List[Any]],
+    tokenizer: "PreTrainedTokenizer",
+    template: "Template",
+    data_args: "DataArguments",
+) -> Dict[str, List[List[int]]]:
+    # build inputs with format `<bos> X` and labels with format `Y <eos>`
+    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+
+    for i in range(len(examples["prompt"])):
+        if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
+            continue
+
+        messages = examples["prompt"][i] + examples["response"][i]
+        input_ids, labels = template.encode_oneturn(
+            tokenizer, messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
+        )
+
+        if template.efficient_eos:
+            labels += [tokenizer.eos_token_id]
+
+        model_inputs["input_ids"].append(input_ids)
+        model_inputs["attention_mask"].append([1] * len(input_ids))
+        model_inputs["labels"].append(labels)
+
+    return model_inputs
+
+
+def preprocess_pairwise_dataset(
+    examples: Dict[str, List[Any]],
+    tokenizer: "PreTrainedTokenizer",
+    template: "Template",
+    data_args: "DataArguments",
+) -> Dict[str, List[List[int]]]:
+    # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
+    model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
+    for i in range(len(examples["prompt"])):
+        if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) < 2:
+            continue
+
+        chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
+        rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
+
+        prompt_ids, chosen_ids = template.encode_oneturn(
+            tokenizer, chosen_messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
+        )
+        _, rejected_ids = template.encode_oneturn(
+            tokenizer, rejected_messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
+        )
+
+        if template.efficient_eos:
+            chosen_ids += [tokenizer.eos_token_id]
+            rejected_ids += [tokenizer.eos_token_id]
+
+        model_inputs["prompt_ids"].append(prompt_ids)
+        model_inputs["chosen_ids"].append(chosen_ids)
+        model_inputs["rejected_ids"].append(rejected_ids)
+
+    return model_inputs
+
+
+def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+    print("input_ids:\n{}".format(example["input_ids"]))
+    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
+    print("label_ids:\n{}".format(example["labels"]))
+    print("labels:\n{}".format(
+        tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
+    ))
+
+
+def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+    print("prompt_ids:\n{}".format(example["prompt_ids"]))
+    print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
+    print("chosen_ids:\n{}".format(example["chosen_ids"]))
+    print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
+    print("rejected_ids:\n{}".format(example["rejected_ids"]))
+    print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
+
+
+def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+    print("input_ids:\n{}".format(example["input_ids"]))
+    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
+
+
+def get_preprocess_and_print_func(
+    tokenizer: "PreTrainedTokenizer",
+    template: "Template",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    stage: Literal["pt", "sft", "rm", "ppo"],
+) -> Tuple[Callable, Callable]:
+    if stage == "pt":
+        preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
+        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
+    elif stage == "sft" and not training_args.predict_with_generate:
+        if data_args.sft_packing:
+            preprocess_func = partial(
+                preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
+            )
+        else:
+            preprocess_func = partial(
+                preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
+            )
+
+        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
+    elif stage == "rm":
+        preprocess_func = partial(preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args)
+        print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
+    else:
+        preprocess_func = partial(preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args)
+        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
+
+    return preprocess_func, print_function
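A rough usage sketch of the new entry point; the names of the surrounding loader (`dataset`, `tokenizer`, `template`, `data_args`, `training_args`) are assumptions, since the caller is not shown in this diff:

from functools import partial  # the new code binds arguments with functools.partial

preprocess_func, print_function = get_preprocess_and_print_func(
    tokenizer, template, data_args, training_args, stage="sft"
)
column_names = list(next(iter(dataset)).keys())
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names)
print_function(next(iter(dataset)))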
@@ -1,8 +1,10 @@
-import tiktoken
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

-from llmtuner.extras.logging import get_logger
+from ..extras.logging import get_logger
+from .utils import Role
+from .formatter import StringFormatter, FunctionFormatter, ToolFormatter

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
@@ -14,28 +16,30 @@ logger = get_logger(__name__)
 @dataclass
 class Template:

-    prefix: List[Union[str, Dict[str, str]]]
-    prompt: List[Union[str, Dict[str, str]]]
+    format_user: Callable
+    format_assistant: Callable
+    format_system: Callable
+    format_tool: Callable
+    format_observation: Callable
+    format_function: Callable
     system: str
-    sep: List[Union[str, Dict[str, str]]]
+    separator: List[Union[str, Dict[str, str]]]
     stop_words: List[str]
-    use_history: bool
     efficient_eos: bool
     replace_eos: bool

     def encode_oneturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        query: str,
-        resp: str,
-        history: Optional[List[Tuple[str, str]]] = None,
-        system: Optional[str] = None
+        messages: List[Dict[str, str]],
+        system: str,
+        tool: str,
+        cutoff_len: int
     ) -> Tuple[List[int], List[int]]:
         r"""
         Returns a single pair of token ids representing prompt and response respectively.
         """
-        system, history = self._format(query, resp, history, system)
-        encoded_pairs = self._encode(tokenizer, system, history)
+        encoded_pairs = self._encode(tokenizer, messages, system, tool, cutoff_len)
         prompt_ids = []
         for query_ids, resp_ids in encoded_pairs[:-1]:
             prompt_ids = prompt_ids + query_ids + resp_ids
@@ -46,109 +50,75 @@ class Template:
     def encode_multiturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        query: str,
-        resp: str,
-        history: Optional[List[Tuple[str, str]]] = None,
-        system: Optional[str] = None
+        messages: List[Dict[str, str]],
+        system: str,
+        tool: str,
+        cutoff_len: int
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
         """
-        system, history = self._format(query, resp, history, system)
-        encoded_pairs = self._encode(tokenizer, system, history)
+        encoded_pairs = self._encode(tokenizer, messages, system, tool, cutoff_len)
         return encoded_pairs

-    def _format(
-        self,
-        query: str,
-        resp: str,
-        history: Optional[List[Tuple[str, str]]] = None,
-        system: Optional[str] = None
-    ) -> Tuple[str, List[Tuple[str, str]]]:
-        r"""
-        Aligns inputs to the standard format.
-        """
-        system = system or self.system  # use system if provided
-        history = history if (history and self.use_history) else []
-        history = history + [(query, resp)]
-        return system, history
-
-    def _get_special_ids(
-        self,
-        tokenizer: "PreTrainedTokenizer"
-    ) -> Tuple[List[int], List[int]]:
-        if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
-            bos_ids = [tokenizer.bos_token_id]
-        else:  # baichuan, gpt2, qwen, yi models have no bos token
-            bos_ids = []
-
-        if tokenizer.eos_token_id is None:
-            raise ValueError("EOS token is required.")
-
-        if self.efficient_eos:
-            eos_ids = []
-        else:
-            eos_ids = [tokenizer.eos_token_id]
-
-        return bos_ids, eos_ids
-
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
+        messages: List[Dict[str, str]],
         system: str,
-        history: List[Tuple[str, str]]
+        tool: str,
+        cutoff_len: int
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Encodes formatted inputs to pairs of token ids.
-        Turn 0: bos + prefix + sep + query    resp + eos
-        Turn t: sep + bos + query             resp + eos
+        Turn 0: system + query    resp + eos
+        Turn t: sep + query       resp + eos
         """
-        bos_ids, eos_ids = self._get_special_ids(tokenizer)
-        sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
-        encoded_pairs = []
-        for turn_idx, (query, resp) in enumerate(history):
-            if turn_idx == 0:
-                prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
-                if len(prefix_ids) != 0:  # has prefix
-                    prefix_ids = bos_ids + prefix_ids + sep_ids
-                else:
-                    prefix_ids = bos_ids
-            else:
-                prefix_ids = sep_ids + bos_ids
-
-            query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx+1))
-            resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
-            encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
-        return encoded_pairs
+        system = system or self.system
+        encoded_messages = []
+        for i, message in enumerate(messages):
+            elements = []
+            if i == 0 and (system or tool):
+                tool_text = self.format_tool(content=tool)[0] if tool else ""
+                elements += self.format_system(content=(system + tool_text))
+            elif i > 0 and i % 2 == 0:
+                elements += self.separator
+
+            if message["role"] == Role.USER:
+                elements += self.format_user(content=message["content"], idx=str(i // 2))
+            elif message["role"] == Role.ASSISTANT:
+                elements += self.format_assistant(content=message["content"])
+            elif message["role"] == Role.OBSERVATION:
+                elements += self.format_observation(content=message["content"])
+            elif message["role"] == Role.FUNCTION:
+                elements += self.format_function(content=message["content"])
+
+            encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
+
+        return [(encoded_messages[i], encoded_messages[i+1]) for i in range(0, len(encoded_messages), 2)]

-    def _convert_inputs_to_ids(
+    def _convert_elements_to_ids(
         self,
         tokenizer: "PreTrainedTokenizer",
-        context: List[Union[str, Dict[str, str]]],
-        system: Optional[str] = None,
-        query: Optional[str] = None,
-        idx: Optional[str] = None
+        elements: List[Union[str, Dict[str, str]]]
     ) -> List[int]:
         r"""
-        Converts context to token ids.
+        Converts elements to token ids.
         """
-        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding):  # for tiktoken tokenizer (Qwen)
-            kwargs = dict(allowed_special="all")
-        else:
-            kwargs = dict(add_special_tokens=False)
-
         token_ids = []
-        for elem in context:
+        for elem in elements:
             if isinstance(elem, str):
-                elem = elem.replace("{{system}}", system, 1) if system is not None else elem
-                elem = elem.replace("{{query}}", query, 1) if query is not None else elem
-                elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
                 if len(elem) != 0:
-                    token_ids = token_ids + tokenizer.encode(elem, **kwargs)
+                    token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False)
             elif isinstance(elem, dict):
                 token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
+            elif isinstance(elem, set):
+                if "bos_token" in elem and tokenizer.bos_token_id:
+                    token_ids = token_ids + [tokenizer.bos_token_id]
+                elif "eos_token" in elem and tokenizer.eos_token_id:
+                    token_ids = token_ids + [tokenizer.eos_token_id]
             else:
-                raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem)))
+                raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))

         return token_ids
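The formatter objects referenced above live in the new .formatter module, which is not part of this excerpt. A minimal sketch of the protocol they appear to follow, inferred from the call sites (this is an assumption, not the shipped implementation; ToolFormatter is omitted):

import json
from typing import Any, Dict, List, Set, Union

Element = Union[str, Dict[str, str], Set[str]]


class StringFormatter:
    # Substitutes {{placeholders}} in the string slots of its container;
    # dict/set slots (special tokens) pass through untouched.
    def __init__(self, container: List[Element]) -> None:
        self.container = container

    def __call__(self, **kwargs: Any) -> List[Element]:
        elements: List[Element] = []
        for elem in self.container:
            if isinstance(elem, str):
                for name, value in kwargs.items():
                    elem = elem.replace("{{" + name + "}}", str(value))
            elements.append(elem)
        return elements


class FunctionFormatter(StringFormatter):
    # A function turn carries a JSON payload {"name": ..., "arguments": ...};
    # this expands it into the container's {{name}}/{{arguments}} slots.
    def __call__(self, content: str = "", **kwargs: Any) -> List[Element]:
        function = json.loads(content)
        return super().__call__(name=function["name"], arguments=json.dumps(function["arguments"]))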
@@ -159,23 +129,39 @@ class Llama2Template(Template):
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
+        messages: List[Dict[str, str]],
         system: str,
-        history: List[Tuple[str, str]]
+        tool: str,
+        cutoff_len: int
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Encodes formatted inputs to pairs of token ids.
-        Turn 0: bos + prefix + query    resp + eos
-        Turn t: bos + query             resp + eos
+        Turn 0: system + query    resp + eos
+        Turn t: sep + query       resp + eos
         """
-        bos_ids, eos_ids = self._get_special_ids(tokenizer)
-        encoded_pairs = []
-        for turn_idx, (query, resp) in enumerate(history):
-            if turn_idx == 0:  # llama2 template has no sep_ids
-                query = self.prefix[0].replace("{{system}}", system) + query
-            query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
-            resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
-            encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
-        return encoded_pairs
+        system = system or self.system
+        encoded_messages = []
+        for i, message in enumerate(messages):
+            elements = []
+            system_text = ""
+            if i == 0 and (system or tool):
+                tool_text = self.format_tool(content=tool)[0] if tool else ""
+                system_text = self.format_system(content=(system + tool_text))[0]
+            elif i > 0 and i % 2 == 0:
+                elements += self.separator
+
+            if message["role"] == Role.USER:
+                elements += self.format_user(content=system_text + message["content"], idx=str(i // 2))
+            elif message["role"] == Role.ASSISTANT:
+                elements += self.format_assistant(content=message["content"])
+            elif message["role"] == Role.OBSERVATION:
+                elements += self.format_observation(content=message["content"])
+            elif message["role"] == Role.FUNCTION:
+                elements += self.format_function(content=message["content"])
+
+            encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
+
+        return [(encoded_messages[i], encoded_messages[i+1]) for i in range(0, len(encoded_messages), 2)]


 templates: Dict[str, Template] = {}
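The pairing at the end of _encode groups the encoded messages two by two, so a tool-calling conversation yields one (prompt, response) pair per model turn. A small illustration with plain lists standing in for token ids (values are illustrative only):

# Messages alternate between "input" roles (user/observation) and "output" roles
# (assistant/function), so pairing even indices with odd indices recovers the turns.
encoded_messages = [[1, 2], [3], [4, 5], [6, 7]]  # user, function, observation, assistant
pairs = [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
print(pairs)  # [([1, 2], [3]), ([4, 5], [6, 7])]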
@@ -183,23 +169,33 @@ templates: Dict[str, Template] = {}

 def register_template(
     name: str,
-    prefix: List[Union[str, Dict[str, str]]],
-    prompt: List[Union[str, Dict[str, str]]],
-    system: str,
-    sep: List[Union[str, Dict[str, str]]],
+    format_user: Optional[Callable] = None,
+    format_assistant: Optional[Callable] = None,
+    format_system: Optional[Callable] = None,
+    format_tool: Optional[Callable] = None,
+    format_observation: Optional[Callable] = None,
+    format_function: Optional[Callable] = None,
+    system: Optional[str] = "",
+    separator: Optional[List[Union[str, Dict[str, str]]]] = "",
     stop_words: Optional[List[str]] = [],
-    use_history: Optional[bool] = True,
     efficient_eos: Optional[bool] = False,
     replace_eos: Optional[bool] = False
 ) -> None:
     template_class = Llama2Template if name.startswith("llama2") else Template
     templates[name] = template_class(
-        prefix=prefix,
-        prompt=prompt,
+        format_user=format_user or StringFormatter(container=["{{content}}"]),
+        format_assistant=format_assistant or StringFormatter(container=[
+            "{{content}}", {"eos_token"}
+        ]),
+        format_system=format_system or StringFormatter(container=["{{content}}"]),
+        format_tool=format_tool or ToolFormatter(type="default"),
+        format_observation=format_observation or format_user,
+        format_function=format_function or FunctionFormatter(container=[
+            "Action: {{name}}\nAction Input: {{arguments}}", {"eos_token"}
+        ]),
         system=system,
-        sep=sep,
+        separator=separator,
         stop_words=stop_words,
-        use_history=use_history,
         efficient_eos=efficient_eos,
         replace_eos=replace_eos
     )
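Registering a template under the new API takes keyword formatters instead of prefix/prompt lists. For example, a hypothetical template (not in this commit; the name and markers are made up) could be added like this:

register_template(
    name="my_chat",
    format_user=StringFormatter(container=["<user>{{content}}</user><bot>"]),
    system="You are a helpful assistant.",
    separator=["\n"]
)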
@@ -244,17 +240,14 @@ def get_template_and_fix_tokenizer(

 register_template(
     name="alpaca",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        "### Instruction:\n{{query}}\n\n### Response:\n"
-    ],
+    format_user=StringFormatter(container=[
+        "### Instruction:\n{{content}}\n\n### Response:\n"
+    ]),
     system=(
         "Below is an instruction that describes a task. "
         "Write a response that appropriately completes the request."
     ),
-    sep=[
+    separator=[
         "\n\n"
     ]
 )
@@ -262,17 +255,14 @@ register_template(

 register_template(
     name="aquila",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        "Human: {{query}}###Assistant:"
-    ],
+    format_user=StringFormatter(container=[
+        "Human: {{content}}###Assistant:"
+    ]),
     system=(
         "A chat between a curious human and an artificial intelligence assistant. "
         "The assistant gives helpful, detailed, and polite answers to the human's questions."
     ),
-    sep=[
+    separator=[
         "###"
     ],
     stop_words=[
@@ -284,46 +274,32 @@ register_template(

 register_template(
     name="baichuan",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        {"token": "<reserved_102>"},  # user token
-        "{{query}}",
-        {"token": "<reserved_103>"}  # assistant token
-    ],
-    system="",
-    sep=[],
+    format_user=StringFormatter(container=[
+        {"token": "<reserved_102>"},
+        "{{content}}",
+        {"token": "<reserved_103>"}
+    ]),
     efficient_eos=True
 )


 register_template(
     name="baichuan2",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        {"token": "<reserved_106>"},  # user token
-        "{{query}}",
-        {"token": "<reserved_107>"}  # assistant token
-    ],
-    system="",
-    sep=[],
+    format_user=StringFormatter(container=[
+        {"token": "<reserved_106>"},
+        "{{content}}",
+        {"token": "<reserved_107>"}
+    ]),
     efficient_eos=True
 )


 register_template(
     name="belle",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        "Human: {{query}}\n\nBelle: "
-    ],
-    system="",
-    sep=[
+    format_user=StringFormatter(container=[
+        "Human: {{content}}\n\nBelle: "
+    ]),
+    separator=[
         "\n\n"
     ]
 )
@@ -331,31 +307,25 @@ register_template(

 register_template(
     name="bluelm",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
+    format_user=StringFormatter(container=[
         {"token": "[|Human|]:"},
-        "{{query}}",
+        "{{content}}",
         {"token": "[|AI|]:"}
-    ],
-    system="",
-    sep=[]
+    ])
 )


 register_template(
     name="chatglm2",
-    prefix=[
+    format_user=StringFormatter(container=[
+        "[Round {{idx}}]\n\n问:{{content}}\n\n答:"
+    ]),
+    format_system=StringFormatter(container=[
         {"token": "[gMASK]"},
         {"token": "sop"},
-        "{{system}}"
-    ],
-    prompt=[
-        "[Round {{idx}}]\n\n问:{{query}}\n\n答:"
-    ],
-    system="",
-    sep=[
+        "{{content}}"
+    ]),
+    separator=[
         "\n\n"
     ],
     efficient_eos=True
@@ -364,53 +334,35 @@ register_template(

 register_template(
     name="chatglm3",
-    prefix=[
-        {"token": "[gMASK]"},
-        {"token": "sop"},
-        {"token": "<|system|>"},
-        "\n",
-        "{{system}}"
-    ],
-    prompt=[
+    format_user=StringFormatter(container=[
         {"token": "<|user|>"},
         "\n",
-        "{{query}}",
-        {"token": "<|assistant|>"},
-        "\n"  # add an extra newline to avoid error in ChatGLM's process_response method
-    ],
-    system=(
-        "You are ChatGLM3, a large language model trained by Zhipu.AI. "
-        "Follow the user's instructions carefully. Respond using markdown."
-    ),
-    sep=[],
-    stop_words=[
-        "<|user|>",
-        "<|observation|>"
-    ],
-    efficient_eos=True
-)
-
-
-register_template(
-    name="chatglm3_raw",  # the raw template for tool tuning
-    prefix=[
-        {"token": "[gMASK]"},
-        {"token": "sop"},
-        {"token": "<|system|>"},
-        "\n",
-        "{{system}}"
-    ],
-    prompt=[
-        {"token": "<|user|>"},
-        "\n",
-        "{{query}}",
+        "{{content}}",
         {"token": "<|assistant|>"}
-    ],
+    ]),
+    format_assistant=StringFormatter(container=[
+        "\n"
+        "{{content}}"
+    ]),
+    format_system=StringFormatter(container=[
+        {"token": "[gMASK]"},
+        {"token": "sop"},
+        {"token": "<|system|>"},
+        "\n",
+        "{{content}}"
+    ]),
+    format_observation=StringFormatter(container=[
+        {"token": "<|observation|>"},
+        "\n",
+        "{{content}}"
+    ]),
+    format_function=FunctionFormatter(container=[
+        "{{name}}\n{{arguments}}"
+    ]),
     system=(
         "You are ChatGLM3, a large language model trained by Zhipu.AI. "
         "Follow the user's instructions carefully. Respond using markdown."
     ),
-    sep=[],
     stop_words=[
         "<|user|>",
         "<|observation|>"
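With format_observation and format_function in place, a tool-calling exchange can be expressed as a plain message list before encoding. A sketch of the expected shape (field values are illustrative; the role strings follow the new Role enum added in the data utils module):

messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "function", "content": "{\"name\": \"calculator\", \"arguments\": {\"expression\": \"2 + 2\"}}"},
    {"role": "observation", "content": "{\"result\": 4}"},
    {"role": "assistant", "content": "2 + 2 equals 4."},
]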
@@ -421,47 +373,34 @@ register_template(

 register_template(
     name="codegeex2",
-    prefix=[
+    format_system=StringFormatter(container=[
         {"token": "[gMASK]"},
         {"token": "sop"},
-        "{{system}}"
-    ],
-    prompt=[
-        "{{query}}"
-    ],
-    system="",
-    sep=[]
+        "{{content}}"
+    ])
 )


 register_template(
     name="deepseek",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        "User: {{query}}\n\nAssistant:"
-    ],
-    system="",
-    sep=[]
+    format_user=StringFormatter(container=[
+        "User: {{content}}\n\nAssistant:"
+    ])
 )


 register_template(
     name="deepseekcoder",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        "### Instruction:\n{{query}}\n### Response:\n"
-    ],
+    format_user=StringFormatter(container=[
+        "### Instruction:\n{{content}}\n### Response:\n"
+    ]),
     system=(
         "You are an AI programming assistant, utilizing the Deepseek Coder model, "
         "developed by Deepseek Company, and you only answer questions related to computer science. "
         "For politically sensitive questions, security and privacy issues, "
         "and other non-computer science questions, you will refuse to answer\n"
     ),
-    sep=[
+    separator=[
         "\n",
         {"token": "<|EOT|>"},
         "\n"
@@ -475,17 +414,14 @@ register_template(

 register_template(
     name="default",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        "Human: {{query}}\nAssistant:"
-    ],
+    format_user=StringFormatter(container=[
+        "Human: {{content}}\nAssistant: "
+    ]),
     system=(
         "A chat between a curious user and an artificial intelligence assistant. "
-        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+        "The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
     ),
-    sep=[
+    separator=[
         "\n"
     ]
 )
@@ -493,14 +429,10 @@ register_template(

 register_template(
     name="falcon",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        "User: {{query}}\nFalcon:"
-    ],
-    system="",
-    sep=[
+    format_user=StringFormatter(container=[
+        "User: {{content}}\nFalcon:"
+    ]),
+    separator=[
         "\n"
     ],
     efficient_eos=True
@@ -509,16 +441,12 @@ register_template(

 register_template(
     name="intern",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        "<|User|>:{{query}}",
+    format_user=StringFormatter(container=[
+        "<|User|>:{{content}}",
         {"token": "<eoh>"},
         "\n<|Bot|>:"
-    ],
-    system="",
-    sep=[
+    ]),
+    separator=[
         {"token": "<eoa>"},
         "\n"
     ],
@@ -529,14 +457,44 @@ register_template(
 )


+register_template(
+    name="intern2",
+    format_user=StringFormatter(container=[
+        {"token": "[UNUSED_TOKEN_146]"},
+        "user\n{{content}}",
+        {"token": "[UNUSED_TOKEN_145]"},
+        "\n",
+        {"token": "[UNUSED_TOKEN_146]"},
+        "assistant\n"
+    ]),
+    format_system=StringFormatter(container=[
+        {"token": "[UNUSED_TOKEN_146]"},
+        "system\n{{content}}",
+        {"token": "[UNUSED_TOKEN_145]"},
+        "\n"
+    ]),
+    system=(
+        "You are an AI assistant whose name is InternLM (书生·浦语).\n"
+        "- InternLM (书生·浦语) is a conversational language model that is developed "
+        "by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
+        "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
+        "by the user such as English and 中文."
+    ),
+    separator=[
+        {"token": "[UNUSED_TOKEN_145]"},
+        "\n"
+    ],
+    stop_words=[
+        "[UNUSED_TOKEN_145]"
+    ],
+    efficient_eos=True
+)
+
+
 register_template(
     name="llama2",
-    prefix=[
-        "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
-    ],
-    prompt=[
-        "[INST] {{query}} [/INST]"
-    ],
+    format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]),
+    format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
     system=(
         "You are a helpful, respectful and honest assistant. "
         "Always answer as helpfully as possible, while being safe. "
@@ -546,49 +504,32 @@ register_template(
         "If a question does not make any sense, or is not factually coherent, "
         "explain why instead of answering something not correct. "
         "If you don't know the answer to a question, please don't share false information."
-    ),
-    sep=[]
+    )
 )


 register_template(
     name="llama2_zh",
-    prefix=[
-        "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
-    ],
-    prompt=[
-        "[INST] {{query}} [/INST]"
-    ],
-    system="You are a helpful assistant. 你是一个乐于助人的助手。",
-    sep=[]
+    format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]),
+    format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
+    system="You are a helpful assistant. 你是一个乐于助人的助手。"
 )


 register_template(
     name="mistral",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        "[INST] {{query}} [/INST]"
-    ],
-    system="",
-    sep=[]
+    format_user=StringFormatter(container=["[INST] {{content}} [/INST]"])
 )


 register_template(
     name="openchat",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        "GPT4 Correct User: {{query}}",
+    format_user=StringFormatter(container=[
+        "GPT4 Correct User: {{content}}",
         {"token": "<|end_of_turn|>"},
         "GPT4 Correct Assistant:"
-    ],
-    system="",
-    sep=[
+    ]),
+    separator=[
         {"token": "<|end_of_turn|>"}
     ],
     stop_words=[
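Under the Llama2Template encoding shown earlier, the system prompt is folded into the first user turn, so a one-turn prompt renders roughly as follows (a sketch with an illustrative query, using the llama2_zh defaults above):

system = "You are a helpful assistant. 你是一个乐于助人的助手。"
query = "你好"  # illustrative value
# format_system produces system_text, which format_user prepends to the first query
prompt = "[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]".format(system, query)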
@@ -600,14 +541,14 @@ register_template(

 register_template(
     name="qwen",
-    prefix=[
-        "<|im_start|>system\n{{system}}<|im_end|>"
-    ],
-    prompt=[
-        "<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n"
-    ],
+    format_user=StringFormatter(container=[
+        "<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
+    ]),
+    format_system=StringFormatter(container=[
+        "<|im_start|>system\n{{content}}<|im_end|>\n"
+    ]),
     system="You are a helpful assistant.",
-    sep=[
+    separator=[
         "\n"
     ],
     stop_words=[
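Concatenating the qwen containers for a single exchange yields the usual ChatML layout. A quick sketch of one turn as text (the query and reply are illustrative):

system = "You are a helpful assistant."
query, answer = "Hi!", "Hello! How can I help you today?"  # illustrative values
text = (
    "<|im_start|>system\n{}<|im_end|>\n".format(system)                          # format_system
    + "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n".format(query)    # format_user
    + answer                                                                     # assistant reply
)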
@@ -619,32 +560,28 @@ register_template(

 register_template(
     name="solar",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        "### User:\n{{query}}\n\n### Assistant:\n"
-    ],
-    system="",
-    sep=[]
+    format_user=StringFormatter(container=[
+        "### User:\n{{content}}\n\n### Assistant:\n"
+    ])
 )


 register_template(
     name="starchat",
-    prefix=[
-        {"token": "<|system|>"},
-        "\n{{system}}",
-    ],
-    prompt=[
+    format_user=StringFormatter(container=[
         {"token": "<|user|>"},
-        "\n{{query}}",
+        "\n{{content}}",
         {"token": "<|end|>"},
         "\n",
         {"token": "<|assistant|>"}
-    ],
-    system="",
-    sep=[
+    ]),
+    format_system=StringFormatter(container=[
+        {"token": "<|system|>"},
+        "\n{{content}}",
+        {"token": "<|end|>"},
+        "\n"
+    ]),
+    separator=[
         {"token": "<|end|>"},
         "\n"
     ],
@@ -656,75 +593,55 @@ register_template(


 register_template(
-    name="vanilla",
-    prefix=[],
-    prompt=[
-        "{{query}}"
-    ],
-    system="",
-    sep=[],
-    use_history=False
+    name="vanilla"
 )


 register_template(
     name="vicuna",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        "USER: {{query}} ASSISTANT:"
-    ],
+    format_user=StringFormatter(container=[
+        "USER: {{content}} ASSISTANT:"
+    ]),
     system=(
         "A chat between a curious user and an artificial intelligence assistant. "
         "The assistant gives helpful, detailed, and polite answers to the user's questions."
-    ),
-    sep=[]
+    )
 )


 register_template(
     name="xuanyuan",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        "Human: {{query}} Assistant:"
-    ],
+    format_user=StringFormatter(container=[
+        "Human: {{content}} Assistant:"
+    ]),
     system=(
         "以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
         "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
         "不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
-    ),
-    sep=[]
+    )
 )


 register_template(
     name="xverse",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        "Human: {{query}}\n\nAssistant: "
-    ],
-    system="",
-    sep=[]
+    format_user=StringFormatter(container=[
+        "Human: {{content}}\n\nAssistant: "
+    ])
 )


 register_template(
     name="yayi",
-    prefix=[
-        {"token": "<|System|>"},
-        ":\n{{system}}"
-    ],
-    prompt=[
+    format_user=StringFormatter(container=[
         {"token": "<|Human|>"},
-        ":\n{{query}}\n\n",
+        ":\n{{content}}\n\n",
         {"token": "<|YaYi|>"},
         ":"
-    ],
+    ]),
+    format_system=StringFormatter(container=[
+        {"token": "<|System|>"},
+        ":\n{{content}}\n\n"
+    ]),
     system=(
         "You are a helpful, respectful and honest assistant named YaYi "
         "developed by Beijing Wenge Technology Co.,Ltd. "
@@ -736,7 +653,7 @@ register_template(
         "explain why instead of answering something not correct. "
         "If you don't know the answer to a question, please don't share false information."
     ),
-    sep=[
+    separator=[
         "\n\n"
     ],
     stop_words=[
@@ -747,14 +664,10 @@ register_template(

 register_template(
     name="yi",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        "<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n"
-    ],
-    system="",
-    sep=[
+    format_user=StringFormatter(container=[
+        "<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
+    ]),
+    separator=[
         "\n"
     ],
     stop_words=[
@@ -766,15 +679,11 @@ register_template(

 register_template(
     name="yuan",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        "{{query}}",
+    format_user=StringFormatter(container=[
+        "{{content}}",
         {"token": "<sep>"}
-    ],
-    system="",
-    sep=[
+    ]),
+    separator=[
         "\n"
     ],
     stop_words=[
@@ -786,30 +695,25 @@ register_template(

 register_template(
     name="zephyr",
-    prefix=[
-        "<|system|>\n{{system}}</s>",
-    ],
-    prompt=[
-        "<|user|>\n{{query}}</s><|assistant|>"
-    ],
-    system="You are a friendly chatbot who always responds in the style of a pirate",
-    sep=[]
+    format_user=StringFormatter(container=[
+        "<|user|>\n{{content}}</s><|assistant|>"
+    ]),
+    format_system=StringFormatter(container=[
+        "<|system|>\n{{content}}</s>",
+    ]),
+    system="You are a friendly chatbot who always responds in the style of a pirate"
 )


 register_template(
     name="ziya",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
+    format_user=StringFormatter(container=[
         {"token": "<human>"},
-        ":{{query}}\n",
+        ":{{content}}\n",
         {"token": "<bot>"},
         ":"
-    ],
-    system="",
-    sep=[
+    ]),
+    separator=[
         "\n"
     ]
 )
@@ -1,7 +1,8 @@
 import hashlib
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from enum import Enum, unique
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

-from llmtuner.extras.logging import get_logger
+from ..extras.logging import get_logger

 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
@@ -12,6 +13,14 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


+@unique
+class Role(str, Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    OBSERVATION = "observation"
+    FUNCTION = "function"
+
+
 def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
     if file_sha1 is None:
         logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
@@ -27,6 +36,13 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
         logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))


+def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
+    max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
+    max_target_len = max(max_target_len, data_args.reserved_label_len)
+    max_source_len = data_args.cutoff_len - max_target_len
+    return max_source_len, max_target_len
+
+
 def split_dataset(
     dataset: Union["Dataset", "IterableDataset"],
     data_args: "DataArguments",
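infer_max_len splits the token budget proportionally between source and target, with a floor on the target side. A worked example with made-up numbers:

# cutoff_len=512, reserved_label_len=16, a 300-token source and 100-token target:
cutoff_len, reserved_label_len = 512, 16
source_len, target_len = 300, 100
max_target_len = max(int(cutoff_len * target_len / (source_len + target_len)), reserved_label_len)
max_source_len = cutoff_len - max_target_len
print(max_source_len, max_target_len)  # 384 128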
@@ -1 +1,4 @@
-from llmtuner.eval.evaluator import Evaluator
+from .evaluator import Evaluator
+
+
+__all__ = ["Evaluator"]
@@ -3,7 +3,6 @@
 import os
 import json
 import torch
-import tiktoken
 import numpy as np
 from tqdm import tqdm, trange
 from typing import Any, Dict, List, Optional
@@ -11,10 +10,11 @@ from typing import Any, Dict, List, Optional
 from datasets import load_dataset
 from transformers.utils import cached_file

-from llmtuner.data.template import get_template_and_fix_tokenizer
-from llmtuner.eval.template import get_eval_template
-from llmtuner.extras.constants import CHOICES, SUBJECTS
-from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer
+from ..data import get_template_and_fix_tokenizer
+from .template import get_eval_template
+from ..extras.constants import CHOICES, SUBJECTS
+from ..hparams import get_eval_args
+from ..model import dispatch_model, load_model_and_tokenizer


 class Evaluator:
@@ -26,15 +26,9 @@ class Evaluator:
         self.model = dispatch_model(self.model)
         self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
         self.eval_template = get_eval_template(self.eval_args.lang)
-        self.choice_inputs = self._encode_choices()
-
-    def _encode_choices(self) -> List[int]:
-        if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding):  # for tiktoken tokenizer (Qwen)
-            kwargs = dict(allowed_special="all")
-        else:
-            kwargs = dict(add_special_tokens=False)
-
-        return [self.tokenizer.encode(self.eval_template.prefix + ch, **kwargs)[-1] for ch in CHOICES]
+        self.choice_inputs = [self.tokenizer.encode(
+            self.eval_template.prefix + ch, add_special_tokens=False
+        )[-1] for ch in CHOICES]

     @torch.inference_mode()
     def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
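choice_inputs holds one token id per answer letter, so multiple-choice scoring reduces to comparing the model's logits at those ids. A sketch of the idea (tokenizer, logits and the exact prefix value are assumptions; the real prefix comes from the eval template):

CHOICES = ["A", "B", "C", "D"]
prefix = " "  # assumed eval template prefix
choice_inputs = [tokenizer.encode(prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES]
# For each example i, pick the letter whose token id has the highest last-position logit:
# pred = CHOICES[int(logits[i, -1, choice_inputs].argmax())]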
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Tuple

-from llmtuner.extras.constants import CHOICES
+from ..extras.constants import CHOICES

 if TYPE_CHECKING:
     from datasets import Dataset
@@ -44,7 +44,7 @@ class EvalTemplate:
         return query.strip(), resp, history


-eval_templates: Dict[str, EvalTemplate] = {}
+eval_templates: Dict[str, "EvalTemplate"] = {}


 def register_eval_template(
@@ -62,7 +62,7 @@ def register_eval_template(
     )


-def get_eval_template(name: str) -> EvalTemplate:
+def get_eval_template(name: str) -> "EvalTemplate":
     eval_template = eval_templates.get(name, None)
     assert eval_template is not None, "Template {} does not exist.".format(name)
     return eval_template
@@ -6,9 +6,9 @@ from datetime import timedelta
 from transformers import TrainerCallback
 from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR

-from llmtuner.extras.constants import LOG_FILE_NAME
-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import fix_valuehead_checkpoint
+from .constants import LOG_FILE_NAME
+from .logging import get_logger
+from .misc import fix_valuehead_checkpoint


 if TYPE_CHECKING:
@@ -5,6 +5,8 @@ from typing import Dict, Optional

 CHOICES = ["A", "B", "C", "D"]

+DATA_CONFIG = "dataset_info.json"
+
 DEFAULT_MODULE = defaultdict(str)

 DEFAULT_TEMPLATE = defaultdict(str)
@@ -13,8 +13,8 @@ from transformers.utils import (
 )
 from peft import PeftModel

-from llmtuner.extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
-from llmtuner.extras.logging import get_logger
+from .constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
+from .logging import get_logger


 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
@@ -10,7 +10,7 @@ try:
 except ImportError:
     print("Please upgrade `transformers`.")

-from llmtuner.extras.packages import is_flash_attn2_available
+from ..packages import is_flash_attn2_available


 if is_flash_attn2_available():
@@ -4,8 +4,8 @@ import json
 from typing import List, Optional
 from transformers.trainer import TRAINER_STATE_NAME

-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.packages import is_matplotlib_available
+from .logging import get_logger
+from .packages import is_matplotlib_available

 if is_matplotlib_available():
     import matplotlib.pyplot as plt
@@ -3,3 +3,16 @@ from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
+from .parser import get_train_args, get_infer_args, get_eval_args
+
+
+__all__ = [
+    "DataArguments",
+    "EvaluationArguments",
+    "FinetuningArguments",
+    "GeneratingArguments",
+    "ModelArguments",
+    "get_train_args",
+    "get_infer_args",
+    "get_eval_args"
+]
@@ -1,40 +1,7 @@
-import os
-import json
-from typing import List, Literal, Optional
+from typing import Literal, Optional
 from dataclasses import dataclass, field
 
 
-DATA_CONFIG = "dataset_info.json"
-
-
-def use_modelscope() -> bool:
-    return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
-
-
-@dataclass
-class DatasetAttr:
-
-    load_from: Literal["hf_hub", "ms_hub", "script", "file"]
-    dataset_name: Optional[str] = None
-    dataset_sha1: Optional[str] = None
-    subset: Optional[str] = None
-    folder: Optional[str] = None
-    ranking: Optional[bool] = False
-    formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
-
-    prompt: Optional[str] = "instruction"
-    query: Optional[str] = "input"
-    response: Optional[str] = "output"
-    history: Optional[str] = None
-    messages: Optional[str] = "conversations"
-    role: Optional[str] = "from"
-    content: Optional[str] = "value"
-    system: Optional[str] = None
-
-    def __repr__(self) -> str:
-        return self.dataset_name
-
-
 @dataclass
 class DataArguments:
     r"""

@@ -126,64 +93,3 @@ class DataArguments:
 
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
-
-    def init_for_training(self, seed: int):  # support mixing multiple datasets
-        self.seed = seed
-        dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
-        try:
-            with open(os.path.join(self.dataset_dir, DATA_CONFIG), "r") as f:
-                dataset_info = json.load(f)
-        except Exception as err:
-            if self.dataset is not None:
-                raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err)))
-            dataset_info = None
-
-        if self.interleave_probs is not None:
-            self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
-
-        self.dataset_list: List[DatasetAttr] = []
-        for name in dataset_names:
-            if name not in dataset_info:
-                raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
-
-            has_hf_url = "hf_hub_url" in dataset_info[name]
-            has_ms_url = "ms_hub_url" in dataset_info[name]
-
-            if has_hf_url or has_ms_url:
-                if (use_modelscope() and has_ms_url) or (not has_hf_url):
-                    dataset_attr = DatasetAttr(
-                        "ms_hub",
-                        dataset_name=dataset_info[name]["ms_hub_url"]
-                    )
-                else:
-                    dataset_attr = DatasetAttr(
-                        "hf_hub",
-                        dataset_name=dataset_info[name]["hf_hub_url"]
-                    )
-            elif "script_url" in dataset_info[name]:
-                dataset_attr = DatasetAttr(
-                    "script",
-                    dataset_name=dataset_info[name]["script_url"]
-                )
-            else:
-                dataset_attr = DatasetAttr(
-                    "file",
-                    dataset_name=dataset_info[name]["file_name"],
-                    dataset_sha1=dataset_info[name].get("file_sha1", None)
-                )
-
-            if "columns" in dataset_info[name]:
-                dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
-                dataset_attr.query = dataset_info[name]["columns"].get("query", None)
-                dataset_attr.response = dataset_info[name]["columns"].get("response", None)
-                dataset_attr.history = dataset_info[name]["columns"].get("history", None)
-                dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
-                dataset_attr.role = dataset_info[name]["columns"].get("role", None)
-                dataset_attr.content = dataset_info[name]["columns"].get("content", None)
-                dataset_attr.system = dataset_info[name]["columns"].get("system", None)
-
-            dataset_attr.subset = dataset_info[name].get("subset", None)
-            dataset_attr.folder = dataset_info[name].get("folder", None)
-            dataset_attr.ranking = dataset_info[name].get("ranking", False)
-            dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
-            self.dataset_list.append(dataset_attr)

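The diff removes `DatasetAttr` and `init_for_training` from the data arguments; the resolution logic evidently moves into the data package, which the relative-import hunks below belong to. The one non-obvious rule in the removed branch is how the download hub is chosen. Restated as a tiny standalone function — `pick_hub` is a hypothetical name; the condition is copied from the code above:

```python
import os


def use_modelscope() -> bool:
    # USE_MODELSCOPE_HUB=1 flips dataset downloads from the HF Hub to ModelScope
    return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))


def pick_hub(has_hf_url: bool, has_ms_url: bool) -> str:
    # same rule as the removed init_for_training branch: ModelScope wins when the
    # env toggle is set and an ms_hub_url exists, or when there is no hf_hub_url
    if (use_modelscope() and has_ms_url) or not has_hf_url:
        return "ms_hub"
    return "hf_hub"


print(pick_hub(has_hf_url=True, has_ms_url=True))   # "hf_hub" unless the env var is set
print(pick_hub(has_hf_url=False, has_ms_url=True))  # always "ms_hub"
```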
@@ -43,13 +43,5 @@ class EvaluationArguments:
         )
 
     def __post_init__(self):
-        task_available = []
-        for folder in os.listdir(self.task_dir):
-            if os.path.isdir(os.path.join(self.task_dir, folder)):
-                task_available.append(folder)
-
-        if self.task not in task_available:
-            raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir))
-
         if self.save_dir is not None and os.path.exists(self.save_dir):
             raise ValueError("`save_dir` already exists, use another one.")

@@ -8,14 +8,12 @@ from typing import Any, Dict, Optional, Tuple
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.trainer_utils import get_last_checkpoint
 
-from llmtuner.extras.logging import get_logger
-from llmtuner.hparams import (
-    ModelArguments,
-    DataArguments,
-    EvaluationArguments,
-    FinetuningArguments,
-    GeneratingArguments
-)
+from ..extras.logging import get_logger
+from .data_args import DataArguments
+from .evaluation_args import EvaluationArguments
+from .finetuning_args import FinetuningArguments
+from .generating_args import GeneratingArguments
+from .model_args import ModelArguments
 
 
 logger = get_logger(__name__)

@@ -107,8 +105,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     _set_transformers_logging()
 
     # Check arguments
-    data_args.init_for_training(training_args.seed)
-
     if finetuning_args.stage != "pt" and data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

@@ -1,5 +1,5 @@
-# Level: loader > adapter > parser, utils
-
-from llmtuner.model.loader import load_model_and_tokenizer
-from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
-from llmtuner.model.utils import dispatch_model, get_modelcard_args, load_valuehead_params
+from .loader import load_model_and_tokenizer
+from .utils import dispatch_model, get_modelcard_args, load_valuehead_params
+
+
+__all__ = ["load_model_and_tokenizer", "dispatch_model", "get_modelcard_args", "load_valuehead_params"]

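The new `__all__` fixes the public API of the model package at exactly four names, and argument parsing is no longer its job. A hedged sketch of the resulting call pattern — the positional `is_trainable` flag mirrors the workflow hunks below, everything else is illustrative:

```python
from llmtuner.hparams import get_train_args
from llmtuner.model import load_model_and_tokenizer, dispatch_model

model_args, data_args, training_args, finetuning_args, generating_args = get_train_args()

# third positional argument is the is_trainable flag, as in run_pt/run_sft below
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
model = dispatch_model(model)  # place the model on the available device(s)
```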
@@ -3,12 +3,12 @@ from typing import TYPE_CHECKING
 from transformers.integrations import is_deepspeed_zero3_enabled
 from peft import PeftModel, TaskType, LoraConfig, get_peft_model
 
-from llmtuner.extras.logging import get_logger
-from llmtuner.model.utils import find_all_linear_modules
+from ..extras.logging import get_logger
+from .utils import find_all_linear_modules
 
 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
-    from llmtuner.hparams import ModelArguments, FinetuningArguments
+    from ..hparams import ModelArguments, FinetuningArguments
 
 
 logger = get_logger(__name__)

@@ -4,15 +4,15 @@ from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils.versions import require_version
 from trl import AutoModelForCausalLMWithValueHead
 
-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import count_parameters, get_current_device, try_download_model_from_ms
-from llmtuner.model.adapter import init_adapter
-from llmtuner.model.patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
-from llmtuner.model.utils import load_valuehead_params, register_autoclass
+from ..extras.logging import get_logger
+from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
+from .adapter import init_adapter
+from .patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
+from .utils import load_valuehead_params, register_autoclass
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, PreTrainedTokenizer
-    from llmtuner.hparams import ModelArguments, FinetuningArguments
+    from ..hparams import ModelArguments, FinetuningArguments
 
 
 logger = get_logger(__name__)

@@ -10,15 +10,15 @@ from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTra
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils.versions import require_version
 
-from llmtuner.extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import get_current_device, infer_optim_dtype
-from llmtuner.extras.packages import is_flash_attn2_available
+from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
+from ..extras.logging import get_logger
+from ..extras.misc import get_current_device, infer_optim_dtype
+from ..extras.packages import is_flash_attn2_available
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
     from trl import AutoModelForCausalLMWithValueHead
-    from llmtuner.hparams import ModelArguments
+    from ..hparams import ModelArguments
 
 
 logger = get_logger(__name__)

@@ -4,13 +4,13 @@ from typing import TYPE_CHECKING, Any, Dict, List
 from transformers import PreTrainedModel
 from transformers.utils import cached_file
 
-from llmtuner.extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import get_current_device
+from ..extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
+from ..extras.logging import get_logger
+from ..extras.misc import get_current_device
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+    from ..hparams import ModelArguments, DataArguments, FinetuningArguments
 
 
 logger = get_logger(__name__)

@@ -1 +1,4 @@
-from llmtuner.train.tuner import export_model, run_exp
+from .tuner import export_model, run_exp
+
+
+__all__ = ["export_model", "run_exp"]

@@ -1 +1,4 @@
-from llmtuner.train.dpo.workflow import run_dpo
+from .workflow import run_dpo
+
+
+__all__ = ["run_dpo"]

@@ -5,7 +5,7 @@ from transformers import BatchEncoding, Trainer
 from trl import DPOTrainer
 from trl.trainer.utils import disable_dropout_in_model
 
-from llmtuner.extras.constants import IGNORE_INDEX
+from ...extras.constants import IGNORE_INDEX
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel

@@ -3,18 +3,18 @@
 from typing import TYPE_CHECKING, Optional, List
 from transformers import Seq2SeqTrainingArguments
 
-from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.hparams import ModelArguments
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
-from llmtuner.train.dpo.trainer import CustomDPOTrainer
-from llmtuner.train.utils import create_modelcard_and_push, create_ref_model
+from ...data import get_dataset, split_dataset
+from ...extras.constants import IGNORE_INDEX
+from ...extras.ploting import plot_loss
+from ...hparams import ModelArguments
+from ...model import load_model_and_tokenizer
+from ...train.dpo.collator import DPODataCollatorWithPadding
+from ...train.dpo.trainer import CustomDPOTrainer
+from ...train.utils import create_modelcard_and_push, create_ref_model
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
-    from llmtuner.hparams import DataArguments, FinetuningArguments
+    from ...hparams import DataArguments, FinetuningArguments
 
 
 def run_dpo(

@@ -24,9 +24,8 @@ def run_dpo(
     finetuning_args: "FinetuningArguments",
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
-    dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
-    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
+    dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="rm")
     data_collator = DPODataCollatorWithPadding(
         tokenizer=tokenizer,
         pad_to_multiple_of=8,

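The same two-lines-into-one change repeats in every stage workflow below (ppo, pt, rm, sft): the old raw-load-then-preprocess pair becomes a single `get_dataset` call that also tokenizes, which is why `load_model_and_tokenizer` now has to run first. The calling convention, side by side, with names exactly as they appear in the hunks:

```python
# before: load raw data, then tokenize it in a separate preprocessing step
dataset = get_dataset(model_args, data_args)
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")

# after: one call loads and tokenizes for the requested stage
dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="rm")
```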
@@ -1 +1,4 @@
-from llmtuner.train.ppo.workflow import run_ppo
+from .workflow import run_ppo
+
+
+__all__ = ["run_ppo"]

@@ -13,15 +13,15 @@ from transformers.trainer_pt_utils import remove_dummy_checkpoint
 from trl import PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
 
-from llmtuner.extras.callbacks import LogCallback, FixValueHeadModelCallback
-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
-from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
+from ...extras.callbacks import LogCallback, FixValueHeadModelCallback
+from ...extras.logging import get_logger
+from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
+from .utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
     from trl import AutoModelForCausalLMWithValueHead
-    from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments
+    from ...hparams import ModelArguments, FinetuningArguments, GeneratingArguments
 
 
 logger = get_logger(__name__)

@@ -2,7 +2,7 @@ import json
 import torch
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional
 
-from llmtuner.extras.packages import is_requests_available
+from ...extras.packages import is_requests_available
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel

@@ -7,17 +7,17 @@ from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorWithPadding
 from transformers.optimization import get_scheduler
 
-from llmtuner.data import get_dataset, preprocess_dataset
-from llmtuner.extras.callbacks import FixValueHeadModelCallback
-from llmtuner.extras.misc import fix_valuehead_checkpoint
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.utils import create_ref_model, create_reward_model
-from llmtuner.train.ppo.trainer import CustomPPOTrainer
+from ...data import get_dataset
+from ...extras.callbacks import FixValueHeadModelCallback
+from ...extras.misc import fix_valuehead_checkpoint
+from ...extras.ploting import plot_loss
+from ...model import load_model_and_tokenizer
+from ...train.utils import create_ref_model, create_reward_model
+from ...train.ppo.trainer import CustomPPOTrainer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+    from ...hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
 
 
 def run_ppo(

@@ -28,9 +28,8 @@ def run_ppo(
     generating_args: "GeneratingArguments",
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
-    dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
-    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
+    dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="ppo")
 
     tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

@@ -1 +1,4 @@
-from llmtuner.train.pt.workflow import run_pt
+from .workflow import run_pt
+
+
+__all__ = ["run_pt"]

@@ -4,14 +4,14 @@ import math
 from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForLanguageModeling, Trainer
 
-from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.utils import create_modelcard_and_push
+from ...data import get_dataset, split_dataset
+from ...extras.ploting import plot_loss
+from ...model import load_model_and_tokenizer
+from ...train.utils import create_modelcard_and_push
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+    from ...hparams import ModelArguments, DataArguments, FinetuningArguments
 
 
 def run_pt(

@@ -21,9 +21,8 @@ def run_pt(
     finetuning_args: "FinetuningArguments",
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
-    dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
-    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
+    dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="pt")
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
     # Initialize our Trainer

@@ -1 +1,4 @@
-from llmtuner.train.rm.workflow import run_rm
+from .workflow import run_rm
+
+
+__all__ = ["run_rm"]

@@ -4,7 +4,7 @@ import torch
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 from transformers import Trainer
 
-from llmtuner.extras.logging import get_logger
+from ...extras.logging import get_logger
 
 if TYPE_CHECKING:
     from transformers.trainer import PredictionOutput

@@ -3,19 +3,19 @@
 from typing import TYPE_CHECKING, Optional, List
 from transformers import Seq2SeqTrainingArguments
 
-from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.callbacks import FixValueHeadModelCallback
-from llmtuner.extras.misc import fix_valuehead_checkpoint
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
-from llmtuner.train.rm.metric import compute_accuracy
-from llmtuner.train.rm.trainer import PairwiseTrainer
-from llmtuner.train.utils import create_modelcard_and_push
+from ...data import get_dataset, split_dataset
+from ...extras.callbacks import FixValueHeadModelCallback
+from ...extras.misc import fix_valuehead_checkpoint
+from ...extras.ploting import plot_loss
+from ...model import load_model_and_tokenizer
+from ...train.rm.collator import PairwiseDataCollatorWithPadding
+from ...train.rm.metric import compute_accuracy
+from ...train.rm.trainer import PairwiseTrainer
+from ...train.utils import create_modelcard_and_push
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+    from ...hparams import ModelArguments, DataArguments, FinetuningArguments
 
 
 def run_rm(

@@ -25,9 +25,8 @@ def run_rm(
     finetuning_args: "FinetuningArguments",
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
-    dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
-    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
+    dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="rm")
     data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
 
     # Update arguments

@@ -1 +1,4 @@
-from llmtuner.train.sft.workflow import run_sft
+from .workflow import run_sft
+
+
+__all__ = ["run_sft"]

@@ -2,8 +2,8 @@ import numpy as np
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
 
-from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.packages import (
+from ...extras.constants import IGNORE_INDEX
+from ...extras.packages import (
     is_jieba_available, is_nltk_available, is_rouge_available
 )
 

@@ -6,8 +6,8 @@ import torch.nn as nn
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 from transformers import Seq2SeqTrainer
 
-from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.logging import get_logger
+from ...extras.constants import IGNORE_INDEX
+from ...extras.logging import get_logger
 
 if TYPE_CHECKING:
     from transformers.trainer import PredictionOutput

@@ -3,18 +3,19 @@
 from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
 
-from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.misc import get_logits_processor
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.sft.metric import ComputeMetrics
-from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer
-from llmtuner.train.utils import create_modelcard_and_push
+from ...data import get_dataset, split_dataset
+from ...extras.constants import IGNORE_INDEX
+from ...extras.misc import get_logits_processor
+from ...extras.ploting import plot_loss
+from ...model import load_model_and_tokenizer
+from ...train.sft.metric import ComputeMetrics
+from ...train.sft.trainer import CustomSeq2SeqTrainer
+from ...train.utils import create_modelcard_and_push
+
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+    from ...hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
 
 
 def run_sft(

@@ -25,9 +26,8 @@ def run_sft(
     generating_args: "GeneratingArguments",
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
-    dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
-    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
+    dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="sft")
 
     if training_args.predict_with_generate:
         tokenizer.padding_side = "left"  # use left-padding in generation

@@ -2,14 +2,15 @@ import torch
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 from transformers import PreTrainedModel
 
-from llmtuner.extras.callbacks import LogCallback
-from llmtuner.extras.logging import get_logger
-from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer
-from llmtuner.train.pt import run_pt
-from llmtuner.train.sft import run_sft
-from llmtuner.train.rm import run_rm
-from llmtuner.train.ppo import run_ppo
-from llmtuner.train.dpo import run_dpo
+from ..extras.callbacks import LogCallback
+from ..extras.logging import get_logger
+from ..hparams import get_train_args, get_infer_args
+from ..model import load_model_and_tokenizer
+from .pt import run_pt
+from .sft import run_sft
+from .rm import run_rm
+from .ppo import run_ppo
+from .dpo import run_dpo
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback

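`tuner.py` now pulls argument parsing from `hparams` and the five stage entry points from its own package; `run_exp` dispatches among them. A hedged reconstruction of that dispatch, inferred from the imports and the workflow signatures above — not the verbatim source:

```python
def run_exp(args=None, callbacks=None):
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)

    # route to the stage-specific workflow; signatures match the run_* hunks above
    if finetuning_args.stage == "pt":
        run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "sft":
        run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
    elif finetuning_args.stage == "rm":
        run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "ppo":
        run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
    elif finetuning_args.stage == "dpo":
        run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
    else:
        raise ValueError("Unknown task: {}.".format(finetuning_args.stage))
```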
@@ -1,15 +1,15 @@
 import torch
 from typing import TYPE_CHECKING, Optional, Union
 
-from llmtuner.extras.logging import get_logger
-from llmtuner.hparams import ModelArguments, FinetuningArguments
-from llmtuner.model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params
+from ..extras.logging import get_logger
+from ..hparams import ModelArguments, FinetuningArguments
+from ..model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, Trainer
     from transformers.modeling_utils import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead
-    from llmtuner.hparams import DataArguments
+    from ..hparams import DataArguments
 
 
 logger = get_logger(__name__)

@@ -1 +1,4 @@
-from llmtuner.webui.interface import create_ui, create_web_demo
+from .interface import create_ui, create_web_demo
+
+
+__all__ = ["create_ui", "create_web_demo"]

@@ -2,14 +2,14 @@ import gradio as gr
 from gradio.components import Component  # cannot use TYPE_CHECKING here
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
 
-from llmtuner.chat import ChatModel
-from llmtuner.extras.misc import torch_gc
-from llmtuner.hparams import GeneratingArguments
-from llmtuner.webui.common import get_save_dir
-from llmtuner.webui.locales import ALERTS
+from ..chat import ChatModel
+from ..extras.misc import torch_gc
+from ..hparams import GeneratingArguments
+from .common import get_save_dir
+from .locales import ALERTS
 
 if TYPE_CHECKING:
-    from llmtuner.webui.manager import Manager
+    from .manager import Manager
 
 
 class WebChatModel(ChatModel):

@@ -5,7 +5,8 @@ from collections import defaultdict
 from typing import Any, Dict, Optional
 from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
 
-from llmtuner.extras.constants import (
+from ..extras.constants import (
+    DATA_CONFIG,
     DEFAULT_MODULE,
     DEFAULT_TEMPLATE,
     PEFT_METHODS,

@@ -13,8 +14,7 @@ from llmtuner.extras.constants import (
     TRAINING_STAGES,
     DownloadSource
 )
-from llmtuner.extras.misc import use_modelscope
-from llmtuner.hparams.data_args import DATA_CONFIG
+from ..extras.misc import use_modelscope
 
 
 ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}

@@ -1,6 +1,11 @@
-from llmtuner.webui.components.top import create_top
-from llmtuner.webui.components.train import create_train_tab
-from llmtuner.webui.components.eval import create_eval_tab
-from llmtuner.webui.components.infer import create_infer_tab
-from llmtuner.webui.components.export import create_export_tab
-from llmtuner.webui.components.chatbot import create_chat_box
+from .top import create_top
+from .train import create_train_tab
+from .eval import create_eval_tab
+from .infer import create_infer_tab
+from .export import create_export_tab
+from .chatbot import create_chat_box
+
+
+__all__ = [
+    "create_top", "create_train_tab", "create_eval_tab", "create_infer_tab", "create_export_tab", "create_chat_box"
+]

@@ -4,7 +4,8 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple
 if TYPE_CHECKING:
     from gradio.blocks import Block
     from gradio.components import Component
-    from llmtuner.webui.engine import Engine
+
+    from ..engine import Engine
 
 
 def create_chat_box(

@@ -3,7 +3,7 @@ import json
 import gradio as gr
 from typing import TYPE_CHECKING, Any, Dict, Tuple
 
-from llmtuner.webui.common import DATA_CONFIG
+from ...extras.constants import DATA_CONFIG
 
 if TYPE_CHECKING:
     from gradio.components import Component

@@ -1,12 +1,13 @@
 import gradio as gr
 from typing import TYPE_CHECKING, Dict
 
-from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
-from llmtuner.webui.components.data import create_preview_box
+from ..common import list_dataset, DEFAULT_DATA_DIR
+from .data import create_preview_box
 
 if TYPE_CHECKING:
     from gradio.components import Component
-    from llmtuner.webui.engine import Engine
+
+    from ..engine import Engine
 
 
 def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:

@@ -1,13 +1,14 @@
 import gradio as gr
 from typing import TYPE_CHECKING, Dict, Generator, List
 
-from llmtuner.train import export_model
-from llmtuner.webui.common import get_save_dir
-from llmtuner.webui.locales import ALERTS
+from ...train import export_model
+from ..common import get_save_dir
+from ..locales import ALERTS
 
 if TYPE_CHECKING:
     from gradio.components import Component
-    from llmtuner.webui.engine import Engine
+
+    from ..engine import Engine
 
 
 GPTQ_BITS = ["8", "4", "3", "2"]

@@ -1,11 +1,12 @@
 import gradio as gr
 from typing import TYPE_CHECKING, Dict
 
-from llmtuner.webui.components.chatbot import create_chat_box
+from .chatbot import create_chat_box
 
 if TYPE_CHECKING:
     from gradio.components import Component
-    from llmtuner.webui.engine import Engine
+
+    from ..engine import Engine
 
 
 def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:

@@ -1,10 +1,10 @@
 import gradio as gr
 from typing import TYPE_CHECKING, Dict
 
-from llmtuner.data.template import templates
-from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
-from llmtuner.webui.common import get_model_path, get_template, list_adapters, save_config
-from llmtuner.webui.utils import can_quantize
+from ...data import templates
+from ...extras.constants import METHODS, SUPPORTED_MODELS
+from ..common import get_model_path, get_template, list_adapters, save_config
+from ..utils import can_quantize
 
 if TYPE_CHECKING:
     from gradio.components import Component

@@ -2,14 +2,15 @@ import gradio as gr
 from typing import TYPE_CHECKING, Dict
 from transformers.trainer_utils import SchedulerType
 
-from llmtuner.extras.constants import TRAINING_STAGES
-from llmtuner.webui.common import list_adapters, list_dataset, DEFAULT_DATA_DIR
-from llmtuner.webui.components.data import create_preview_box
-from llmtuner.webui.utils import gen_plot
+from ...extras.constants import TRAINING_STAGES
+from ..common import list_adapters, list_dataset, DEFAULT_DATA_DIR
+from ..components.data import create_preview_box
+from ..utils import gen_plot
 
 if TYPE_CHECKING:
     from gradio.components import Component
-    from llmtuner.webui.engine import Engine
+
+    from ..engine import Engine
 
 
 def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:

@@ -2,12 +2,12 @@ import gradio as gr
 from gradio.components import Component  # cannot use TYPE_CHECKING here
 from typing import Any, Dict, Generator, Optional
 
-from llmtuner.webui.chatter import WebChatModel
-from llmtuner.webui.common import get_model_path, list_dataset, load_config
-from llmtuner.webui.locales import LOCALES
-from llmtuner.webui.manager import Manager
-from llmtuner.webui.runner import Runner
-from llmtuner.webui.utils import get_time
+from .chatter import WebChatModel
+from .common import get_model_path, list_dataset, load_config
+from .locales import LOCALES
+from .manager import Manager
+from .runner import Runner
+from .utils import get_time
 
 
 class Engine:

@@ -2,7 +2,7 @@ import gradio as gr
 from typing import Optional
 from transformers.utils.versions import require_version
 
-from llmtuner.webui.components import (
+from .components import (
     create_top,
     create_train_tab,
     create_eval_tab,

@@ -10,9 +10,9 @@ from llmtuner.webui.components import (
     create_export_tab,
     create_chat_box
 )
-from llmtuner.webui.common import save_config
-from llmtuner.webui.css import CSS
-from llmtuner.webui.engine import Engine
+from .common import save_config
+from .css import CSS
+from .engine import Engine
 
 
 require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")

@@ -9,17 +9,17 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
 import transformers
 from transformers.trainer import TRAINING_ARGS_NAME
 
-from llmtuner.extras.callbacks import LogCallback
-from llmtuner.extras.constants import TRAINING_STAGES
-from llmtuner.extras.logging import LoggerHandler
-from llmtuner.extras.misc import get_device_count, torch_gc
-from llmtuner.train import run_exp
-from llmtuner.webui.common import get_module, get_save_dir, load_config
-from llmtuner.webui.locales import ALERTS
-from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
+from ..extras.callbacks import LogCallback
+from ..extras.constants import TRAINING_STAGES
+from ..extras.logging import LoggerHandler
+from ..extras.misc import get_device_count, torch_gc
+from ..train import run_exp
+from .common import get_module, get_save_dir, load_config
+from .locales import ALERTS
+from .utils import gen_cmd, get_eval_results, update_process_bar
 
 if TYPE_CHECKING:
-    from llmtuner.webui.manager import Manager
+    from .manager import Manager
 
 
 class Runner:

@@ -4,12 +4,12 @@ import gradio as gr
 from typing import TYPE_CHECKING, Any, Dict
 from datetime import datetime
 
-from llmtuner.extras.packages import is_matplotlib_available
-from llmtuner.extras.ploting import smooth
-from llmtuner.webui.common import get_save_dir
+from ..extras.packages import is_matplotlib_available
+from ..extras.ploting import smooth
+from .common import get_save_dir
 
 if TYPE_CHECKING:
-    from llmtuner.extras.callbacks import LogCallback
+    from ..extras.callbacks import LogCallback
 
 if is_matplotlib_available():
     import matplotlib.figure

@@ -1,6 +1,7 @@
 # coding=utf-8
 # Converts the InternLM2 model in the same format as LLaMA2.
 # Usage: python llamafy_internlm2.py --input_dir input --output_dir output --shard_size 10GB
+# Warning: We have found that the converted model cannot infer correctly. It will be fixed later.
 
 import os
 import fire

@@ -43,19 +44,18 @@ def save_weight(
             llama2_state_dict[key.replace("output", "lm_head")] = value
         elif "tok_embeddings" in key:
             llama2_state_dict[key.replace("tok_embeddings", "embed_tokens")] = value
-        elif "attention_norm" in key:
-            llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value
         elif "wqkv" in key:
-            proj_size = value.size(0)
             num_q_heads = internlm2_config_dict["num_attention_heads"]
             num_kv_heads = internlm2_config_dict["num_key_value_heads"]
-            q_size = proj_size // (num_q_heads + 2 * num_kv_heads) * num_q_heads
-            kv_size = proj_size // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
+            q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads
+            kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
             llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...]
             llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[q_size:q_size+kv_size, ...]
             llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size+kv_size:, ...]
         elif "wo" in key:
             llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value
+        elif "attention_norm" in key:
+            llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value
         elif "ffn_norm" in key:
            llama2_state_dict[key.replace("ffn_norm", "post_attention_layernorm")] = value
         elif "w1" in key:

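The fused `wqkv` weight stacks query, key, and value rows, so its first dimension divides into `num_q_heads + 2 * num_kv_heads` equal head-groups; the two division lines recover how many rows belong to Q versus K/V. A self-contained check of that arithmetic with example head counts (32 query heads, 8 KV heads, head size 128 — illustrative values, not InternLM2's actual config):

```python
import torch

num_q_heads, num_kv_heads, head_dim = 32, 8, 128
# fused projection: (num_q_heads + 2 * num_kv_heads) groups of head_dim rows each
value = torch.zeros((num_q_heads + 2 * num_kv_heads) * head_dim, 4096)

q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads    # 128 * 32 = 4096
kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads  # 128 * 8  = 1024

q_proj = value[:q_size, ...]                  # rows 0..4095
k_proj = value[q_size:q_size + kv_size, ...]  # rows 4096..5119
v_proj = value[q_size + kv_size:, ...]        # rows 5120..6143
assert q_proj.size(0) + k_proj.size(0) + v_proj.size(0) == value.size(0)
```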