support streaming data; fix #284, #274, #268

Former-commit-id: 819cc1353599e5fa45658bc56dd0dbe4b258b197
This commit is contained in:
hiyouga 2023-07-31 23:33:00 +08:00
parent 124f61b404
commit dd3f3e9749
28 changed files with 478 additions and 344 deletions

View File

@@ -12,15 +12,15 @@
 ## Changelog
-[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--prompt_template llama2` argument when you are using the LLaMA-2-chat model.
+[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--template llama2` argument when you are using the LLaMA-2-chat model.
 [23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
-[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--prompt_template baichuan` argument when you are using the Baichuan-13B-Chat model.
+[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--template baichuan` argument when you are using the Baichuan-13B-Chat model.
 [23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
-[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--prompt_template intern` argument when you are using the InternLM-chat model.
+[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--template intern` argument when you are using the InternLM-chat model.
 [23/07/05] Now we support training the **Falcon-7B/40B** models in this repo. Try `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments to use the Falcon model.
@@ -153,6 +153,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --model_name_or_path path_to_your_model \
     --do_train \
     --dataset wiki_demo \
+    --template default \
     --finetuning_type lora \
     --output_dir path_to_pt_checkpoint \
     --overwrite_cache \
@@ -175,6 +176,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --model_name_or_path path_to_your_model \
     --do_train \
     --dataset alpaca_gpt4_en \
+    --template default \
     --finetuning_type lora \
     --output_dir path_to_sft_checkpoint \
     --overwrite_cache \
@@ -197,6 +199,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --model_name_or_path path_to_your_model \
     --do_train \
     --dataset comparison_gpt4_en \
+    --template default \
     --finetuning_type lora \
     --resume_lora_training False \
     --checkpoint_dir path_to_sft_checkpoint \
@@ -220,6 +223,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --model_name_or_path path_to_your_model \
     --do_train \
     --dataset alpaca_gpt4_en \
+    --template default \
     --finetuning_type lora \
     --resume_lora_training False \
     --checkpoint_dir path_to_sft_checkpoint \
@@ -278,6 +282,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --model_name_or_path path_to_your_model \
     --do_eval \
     --dataset alpaca_gpt4_en \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
     --output_dir path_to_eval_result \
@@ -296,6 +301,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --model_name_or_path path_to_your_model \
     --do_predict \
     --dataset alpaca_gpt4_en \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
     --output_dir path_to_predict_result \
@@ -311,6 +317,7 @@ If you want to predict the samples with empty responses, please kindly fill the
 ```bash
 python src/api_demo.py \
     --model_name_or_path path_to_your_model \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint
 ```
@@ -322,6 +329,7 @@ Visit `http://localhost:8000/docs` for API documentation.
 ```bash
 python src/cli_demo.py \
     --model_name_or_path path_to_your_model \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint
 ```
@@ -331,6 +339,7 @@ python src/cli_demo.py \
 ```bash
 python src/web_demo.py \
     --model_name_or_path path_to_your_model \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint
 ```
@@ -340,6 +349,7 @@ python src/web_demo.py \
 ```bash
 python src/export_model.py \
     --model_name_or_path path_to_your_model \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
     --output_dir path_to_export

View File

@@ -12,15 +12,15 @@
 ## Changelog
-[23/07/19] Now we support training the **LLaMA-2** models. Try the `--model_name_or_path meta-llama/Llama-2-7b-hf` argument. Note that using the LLaMA-2-chat model requires adding the `--prompt_template llama2` argument.
+[23/07/19] Now we support training the **LLaMA-2** models. Try the `--model_name_or_path meta-llama/Llama-2-7b-hf` argument. Note that using the LLaMA-2-chat model requires adding the `--template llama2` argument.
 [23/07/18] We developed an all-in-one browser-based UI for training and evaluation. Try `train_web.py` to fine-tune models in your browser. Thanks to [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in developing this feature.
-[23/07/11] Now we support training the **Baichuan-13B** model. Try the `--model_name_or_path path_to_baichuan_model` and `--lora_target W_pack` arguments. Note that using the Baichuan-13B-Chat model requires adding the `--prompt_template baichuan` argument.
+[23/07/11] Now we support training the **Baichuan-13B** model. Try the `--model_name_or_path path_to_baichuan_model` and `--lora_target W_pack` arguments. Note that using the Baichuan-13B-Chat model requires adding the `--template baichuan` argument.
 [23/07/09] We open-sourced [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for quickly editing the factual memory of large language models. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
-[23/07/07] Now we support training the **InternLM-7B** model. Try the `--model_name_or_path internlm/internlm-7b` argument. Note that using the InternLM-chat model requires adding the `--prompt_template intern` argument.
+[23/07/07] Now we support training the **InternLM-7B** model. Try the `--model_name_or_path internlm/internlm-7b` argument. Note that using the InternLM-chat model requires adding the `--template intern` argument.
 [23/07/05] Now we support training the **Falcon-7B/40B** models. Try the `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments.
@@ -153,6 +153,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --model_name_or_path path_to_your_model \
     --do_train \
     --dataset wiki_demo \
+    --template default \
     --finetuning_type lora \
     --output_dir path_to_pt_checkpoint \
     --overwrite_cache \
@@ -174,7 +175,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage sft \
     --model_name_or_path path_to_your_model \
     --do_train \
-    --dataset alpaca_gpt4_en \
+    --dataset alpaca_gpt4_zh \
+    --template default \
     --finetuning_type lora \
     --output_dir path_to_sft_checkpoint \
     --overwrite_cache \
@@ -196,7 +198,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage rm \
     --model_name_or_path path_to_your_model \
     --do_train \
-    --dataset comparison_gpt4_en \
+    --dataset comparison_gpt4_zh \
+    --template default \
     --finetuning_type lora \
     --resume_lora_training False \
     --checkpoint_dir path_to_sft_checkpoint \
@@ -219,7 +222,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage ppo \
     --model_name_or_path path_to_your_model \
     --do_train \
-    --dataset alpaca_gpt4_en \
+    --dataset alpaca_gpt4_zh \
+    --template default \
     --finetuning_type lora \
     --resume_lora_training False \
     --checkpoint_dir path_to_sft_checkpoint \
@@ -277,7 +281,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage sft \
     --model_name_or_path path_to_your_model \
     --do_eval \
-    --dataset alpaca_gpt4_en \
+    --dataset alpaca_gpt4_zh \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
     --output_dir path_to_eval_result \
@@ -295,7 +300,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage sft \
     --model_name_or_path path_to_your_model \
     --do_predict \
-    --dataset alpaca_gpt4_en \
+    --dataset alpaca_gpt4_zh \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
     --output_dir path_to_predict_result \
@@ -311,6 +317,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 ```bash
 python src/api_demo.py \
     --model_name_or_path path_to_your_model \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint
 ```
@@ -322,6 +329,7 @@ python src/api_demo.py \
 ```bash
 python src/cli_demo.py \
     --model_name_or_path path_to_your_model \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint
 ```
@@ -331,6 +339,7 @@ python src/cli_demo.py \
 ```bash
 python src/web_demo.py \
     --model_name_or_path path_to_your_model \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint
 ```
@@ -340,6 +349,7 @@ python src/web_demo.py \
 ```bash
 python src/export_model.py \
     --model_name_or_path path_to_your_model \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
     --output_dir path_to_export

View File

@@ -1,42 +1,50 @@
 import torch
-from typing import Any, Dict, Generator, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer
 
 from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.template import get_template
-from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
 from llmtuner.tuner import load_model_and_tokenizer
 
+if TYPE_CHECKING:
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+
 
 class ChatModel:
 
     def __init__(
         self,
-        model_args: ModelArguments,
-        data_args: DataArguments,
-        finetuning_args: FinetuningArguments,
-        generating_args: GeneratingArguments
+        model_args: "ModelArguments",
+        data_args: "DataArguments",
+        finetuning_args: "FinetuningArguments",
+        generating_args: "GeneratingArguments"
     ) -> None:
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
         if torch.cuda.device_count() > 1:
-            from accelerate import dispatch_model, infer_auto_device_map
-            device_map = infer_auto_device_map(self.model)
+            from accelerate import dispatch_model
+            from accelerate.utils import infer_auto_device_map, get_balanced_memory
+            device_map = infer_auto_device_map(self.model, max_memory=get_balanced_memory(self.model))
             self.model = dispatch_model(self.model, device_map)
         else:
             self.model = self.model.cuda()
-        self.template = get_template(data_args.prompt_template)
-        self.source_prefix = data_args.source_prefix or ""
+        self.template = get_template(data_args.template)
+        self.source_prefix = data_args.source_prefix
         self.generating_args = generating_args
 
     def process_args(
-        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
+        self,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        prefix: Optional[str] = None,
+        **input_kwargs
     ) -> Tuple[Dict[str, Any], int]:
         prefix = prefix or self.source_prefix
-        inputs = self.tokenizer([self.template.get_prompt(query, history, prefix)], return_tensors="pt")
+        prompt = self.template.get_prompt(query, history, prefix, self.tokenizer.eos_token)
+        inputs = self.tokenizer([prompt], return_tensors="pt")
         inputs = inputs.to(self.model.device)
         prompt_length = len(inputs["input_ids"][0])
 
@@ -71,7 +79,11 @@ class ChatModel:
     @torch.inference_mode()
     def chat(
-        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
+        self,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        prefix: Optional[str] = None,
+        **input_kwargs
     ) -> Tuple[str, Tuple[int, int]]:
         gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
         generation_output = self.model.generate(**gen_kwargs)
 
@@ -82,7 +94,11 @@ class ChatModel:
     @torch.inference_mode()
     def stream_chat(
-        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
+        self,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        prefix: Optional[str] = None,
+        **input_kwargs
     ) -> Generator[str, None, None]:
         gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
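
The streaming chat path above builds on the standard `transformers` pattern of running `generate` in a background thread while a `TextIteratorStreamer` yields decoded text as it is produced. A minimal, self-contained sketch of that pattern (the `gpt2` checkpoint is used purely for illustration, not a repository default):

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer(["Streaming generation works by"], return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

# generate() runs in a worker thread and pushes decoded pieces into the streamer
gen_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=32)
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()

for new_text in streamer:  # the main thread consumes text chunks as they arrive
    print(new_text, end="", flush=True)
thread.join()
```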

View File

@@ -1,40 +1,50 @@
 import os
 import hashlib
-from typing import List
-from datasets import Dataset, concatenate_datasets, load_dataset
+from typing import TYPE_CHECKING, List, Optional
+from datasets import concatenate_datasets, interleave_datasets, load_dataset
 
 from llmtuner.extras.logging import get_logger
-from llmtuner.hparams import ModelArguments, DataArguments
+
+if TYPE_CHECKING:
+    from datasets import Dataset
+    from llmtuner.hparams import ModelArguments, DataArguments
 
 
 logger = get_logger(__name__)
 
 
-def get_dataset(
-    model_args: ModelArguments,
-    data_args: DataArguments
-) -> Dataset:
-
-    def checksum(file_path, hash):
-        with open(file_path, "rb") as datafile:
-            binary_data = datafile.read()
-        sha1 = hashlib.sha1(binary_data).hexdigest()
-        if sha1 != hash:
-            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
-
-    ext2type = {
-        "csv": "csv",
-        "json": "json",
-        "jsonl": "json",
-        "txt": "text"
-    }
-
+EXT2TYPE = {
+    "csv": "csv",
+    "json": "json",
+    "jsonl": "json",
+    "txt": "text"
+}
+
+
+def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
+    if file_sha1 is None:
+        logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
+        return
+
+    if len(data_files) != 1:
+        logger.warning("Checksum failed: too many files.")
+        return
+
+    with open(data_files[0], "rb") as f:
+        sha1 = hashlib.sha1(f.read()).hexdigest()
+        if sha1 != file_sha1:
+            logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
+
+
+def get_dataset(
+    model_args: "ModelArguments",
+    data_args: "DataArguments"
+) -> "Dataset":
     max_samples = data_args.max_samples
-    all_datasets: List[Dataset] = [] # support multiple datasets
+    all_datasets: List["Dataset"] = [] # support multiple datasets
     for dataset_attr in data_args.dataset_list:
         logger.info("Loading dataset {}...".format(dataset_attr))
         if dataset_attr.load_from == "hf_hub":
@@ -47,60 +57,56 @@ def get_dataset(
             data_path = None
             data_files: List[str] = []
-            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
+            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # directory
                 for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                     data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
                     if data_path is None:
-                        data_path = ext2type.get(data_files[0].split(".")[-1], None)
+                        data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
                     else:
-                        assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
-            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
+                        assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file type does not match."
+            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # single file
                 data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
-                data_path = ext2type.get(data_files[0].split(".")[-1], None)
+                data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
             else:
                 raise ValueError("File not found.")
 
             assert data_path, "File extension must be txt, csv, json or jsonl."
-
-            if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
-                checksum(data_files[0], dataset_attr.dataset_sha1)
-            else:
-                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
+            checksum(data_files, dataset_attr.dataset_sha1)
         else:
             raise NotImplementedError
 
-        raw_datasets = load_dataset(
+        dataset = load_dataset(
             data_path,
             data_files=data_files,
+            split=data_args.split,
             cache_dir=model_args.cache_dir,
+            streaming=data_args.streaming,
            use_auth_token=True if model_args.use_auth_token else None
         )
-        dataset = raw_datasets[data_args.split]
 
         if max_samples is not None:
             max_samples_temp = min(len(dataset), max_samples)
             dataset = dataset.select(range(max_samples_temp))
 
-        dummy_data = [None] * len(dataset)
-        prefix_data = [dataset_attr.source_prefix] * len(dataset)
-        for column_name, target_name in [
-            ("prompt_column", "prompt"),
-            ("query_column", "query"),
-            ("response_column", "response"),
-            ("history_column", "history")
-        ]: # every dataset will have 4 columns same as each other
-            if getattr(dataset_attr, column_name) != target_name:
-                if getattr(dataset_attr, column_name):
-                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
-                else: # None or empty string
-                    dataset = dataset.add_column(target_name, dummy_data)
-        dataset = dataset.add_column("prefix", prefix_data)
+        for column_name in ["prompt", "query", "response", "history"]: # align datasets
+            if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
+                dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
+
+        if dataset_attr.source_prefix: # add prefix
+            dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix})
+
         all_datasets.append(dataset)
 
     if len(data_args.dataset_list) == 1:
-        all_datasets = all_datasets[0]
+        return all_datasets[0]
+    elif data_args.mix_strategy == "concat":
+        if data_args.streaming:
+            logger.warning("The samples between different datasets will not be mixed in streaming mode.")
+        return concatenate_datasets(all_datasets)
+    elif data_args.mix_strategy.startswith("interleave"):
+        if not data_args.streaming:
+            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
+        stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
+        return interleave_datasets(all_datasets, stopping_strategy=stopping_strategy)
     else:
-        all_datasets = concatenate_datasets(all_datasets)
-
-    return all_datasets
+        raise ValueError("Unknown mixing strategy.")
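
The new `mix_strategy` branch maps `concat` to `concatenate_datasets` and `interleave_under`/`interleave_over` to `interleave_datasets` with the `first_exhausted`/`all_exhausted` stopping strategies. A small sketch with toy in-memory datasets (the example contents are invented for illustration):

```python
from datasets import Dataset, concatenate_datasets, interleave_datasets

ds_a = Dataset.from_dict({"prompt": ["a1", "a2", "a3"], "response": ["r1", "r2", "r3"]})
ds_b = Dataset.from_dict({"prompt": ["b1", "b2"], "response": ["r4", "r5"]})

# mix_strategy="concat": plain concatenation, samples of different datasets are not interleaved
print(concatenate_datasets([ds_a, ds_b])["prompt"])

# mix_strategy="interleave_under": alternate examples and stop once the smallest dataset is exhausted
print(interleave_datasets([ds_a, ds_b], stopping_strategy="first_exhausted")["prompt"])

# mix_strategy="interleave_over": alternate examples, reusing the smaller dataset until the largest is exhausted
print(interleave_datasets([ds_a, ds_b], stopping_strategy="all_exhausted")["prompt"])
```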

View File

@@ -1,65 +1,63 @@
-from typing import Literal
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal
 from itertools import chain
-from transformers import Seq2SeqTrainingArguments
-from transformers.tokenization_utils import PreTrainedTokenizer
-from datasets import Dataset
 
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.template import get_template
-from llmtuner.hparams import DataArguments
+
+if TYPE_CHECKING:
+    from datasets import Dataset
+    from transformers import Seq2SeqTrainingArguments
+    from transformers.tokenization_utils import PreTrainedTokenizer
+    from llmtuner.hparams import DataArguments
 
 
 def preprocess_dataset(
-    dataset: Dataset,
-    tokenizer: PreTrainedTokenizer,
-    data_args: DataArguments,
-    training_args: Seq2SeqTrainingArguments,
+    dataset: "Dataset",
+    tokenizer: "PreTrainedTokenizer",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo"]
-) -> Dataset:
-
-    column_names = list(dataset.column_names)
-    prompt_template = get_template(data_args.prompt_template)
-
-    # support question with a single answer or multiple answers
-    def get_dialog(examples):
+) -> "Dataset":
+    column_names = list(dataset.column_names or [])
+    template = get_template(data_args.template)
+
+    def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
         for i in range(len(examples["prompt"])):
-            if examples["prompt"][i] and examples["response"][i]:
-                query, answer = examples["prompt"][i], examples["response"][i]
-                query = query + "\n" + examples["query"][i] if examples["query"][i] else query
-                prefix = examples["prefix"][i] if examples["prefix"][i] else ""
-                dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
-                yield dialog
+            query, response = examples["prompt"][i], examples["response"][i]
+            query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
+            history = history if "history" in examples and examples["history"][i] else []
+            prefix = prefix if "prefix" in examples and examples["prefix"][i] else ""
+            yield query, response, history, prefix
 
-    def preprocess_pretrain_dataset(examples):
+    def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
         # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
-        text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
-        concatenated_ids = list(chain(*text_ids))
-        total_length = len(concatenated_ids)
-        block_size = data_args.max_source_length - 1
+        tokenized_examples = tokenizer(examples["prompt"], add_special_tokens=False)
+        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
+        total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
+        block_size = data_args.max_source_length
         # we drop the small remainder, and if the total_length < block_size, we exclude this batch
         total_length = (total_length // block_size) * block_size
         # split by chunks of max_source_length
-        result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
-                  for i in range(0, total_length, block_size)]
-        return {
-            "input_ids": result,
-            "labels": result.copy()
+        result = {
+            k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
+            for k, t in concatenated_examples.items()
         }
+        result["labels"] = result["input_ids"].copy()
+        return result
 
-    def preprocess_supervised_dataset(examples):
+    def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
         # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
         # for input with history, we build multiple input-label pairs just like:
         # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
-        model_inputs = {"input_ids": [], "labels": []}
+        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
         max_length = data_args.max_source_length + data_args.max_target_length
-        for dialog in get_dialog(examples):
+        for query, response, history, prefix in construct_example(examples):
             input_ids, labels = [], []
-            for i in range(len(dialog) // 2):
-                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0))
-                target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
+            for i, (query_i, resp_i) in enumerate(template.get_dialog(query, response, history, prefix)):
+                source_ids = tokenizer.encode(text=query_i, add_special_tokens=(i == 0))
+                target_ids = tokenizer.encode(text=resp_i, add_special_tokens=False)
                 if len(source_ids) > data_args.max_source_length:
                     source_ids = source_ids[:data_args.max_source_length]
@@ -73,19 +71,20 @@ def preprocess_dataset(
             labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
 
             model_inputs["input_ids"].append(input_ids)
+            model_inputs["attention_mask"].append([1] * len(input_ids))
             model_inputs["labels"].append(labels)
         return model_inputs
 
-    def preprocess_unsupervised_dataset(examples):
+    def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
         # build inputs with format `<bos> X` and labels with format `<bos> Y`
-        model_inputs = {"input_ids": [], "labels": []}
-        for dialog in get_dialog(examples):
-            prompt, answer = "".join(dialog[:-1]), dialog[-1]
+        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+        for query, response, history, prefix in construct_example(examples):
+            prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
             source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
-            target_ids = tokenizer.encode(text=answer, add_special_tokens=True)
+            target_ids = tokenizer.encode(text=response, add_special_tokens=True)
 
             if len(source_ids) > data_args.max_source_length:
                 source_ids = source_ids[:data_args.max_source_length]
@@ -93,6 +92,7 @@ def preprocess_dataset(
                 target_ids = target_ids[:data_args.max_target_length]
 
             model_inputs["input_ids"].append(source_ids)
+            model_inputs["attention_mask"].append([1] * len(source_ids))
             model_inputs["labels"].append(target_ids)
         return model_inputs
 
@@ -100,12 +100,12 @@ def preprocess_dataset(
     def preprocess_pairwise_dataset(examples):
         # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
         model_inputs = {"accept_ids": [], "reject_ids": []}
-        for dialog in get_dialog(examples):
-            prompt, answer = "".join(dialog[:-1]), dialog[-1]
+        for query, response, history, prefix in construct_example(examples):
+            prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
             source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
-            accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
-            reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
+            accept_ids = tokenizer.encode(text=response[0], add_special_tokens=False)
+            reject_ids = tokenizer.encode(text=response[1], add_special_tokens=False)
 
             if len(source_ids) > data_args.max_source_length:
                 source_ids = source_ids[:data_args.max_source_length]
@@ -141,34 +141,44 @@ def preprocess_dataset(
         print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
 
     if stage == "pt":
+        dataset = dataset.filter(lambda example: example["prompt"])
         preprocess_function = preprocess_pretrain_dataset
-    elif stage == "sft":
-        if not training_args.predict_with_generate:
-            preprocess_function = preprocess_supervised_dataset
-        else:
-            preprocess_function = preprocess_unsupervised_dataset
+    elif stage == "sft" and not training_args.predict_with_generate:
+        dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
+        preprocess_function = preprocess_supervised_dataset
     elif stage == "rm":
+        dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
         preprocess_function = preprocess_pairwise_dataset
-    elif stage == "ppo":
+    else:
+        dataset = dataset.filter(lambda example: example["prompt"])
         preprocess_function = preprocess_unsupervised_dataset
 
     with training_args.main_process_first(desc="dataset map pre-processing"):
-        dataset = dataset.map(
-            preprocess_function,
-            batched=True,
-            num_proc=data_args.preprocessing_num_workers,
-            remove_columns=column_names,
-            load_from_cache_file=not data_args.overwrite_cache,
-            desc="Running tokenizer on dataset"
-        )
+        kwargs = {}
+        if not data_args.streaming:
+            kwargs = dict(
+                num_proc=data_args.preprocessing_num_workers,
+                load_from_cache_file=not data_args.overwrite_cache,
+                desc="Running tokenizer on dataset"
+            )
+
+        dataset = dataset.map(
+            preprocess_function,
+            batched=True,
+            remove_columns=column_names,
+            **kwargs
+        )
+
+        if data_args.streaming:
+            dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
 
     if stage == "pt":
-        print_unsupervised_dataset_example(dataset[0])
+        print_unsupervised_dataset_example(next(iter(dataset)))
     elif stage == "sft":
-        print_supervised_dataset_example(dataset[0])
+        print_supervised_dataset_example(next(iter(dataset)))
     elif stage == "rm":
-        print_pairwise_dataset_example(dataset[0])
+        print_pairwise_dataset_example(next(iter(dataset)))
     elif stage == "ppo":
-        print_unsupervised_dataset_example(dataset[0])
+        print_unsupervised_dataset_example(next(iter(dataset)))
 
     return dataset
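
The conditional `kwargs` above exist because a streaming dataset is an iterable dataset: it does not accept `num_proc` or the cache-related arguments of `map`, it is shuffled through a fixed-size buffer, and it is inspected with `next(iter(...))` rather than integer indexing. A rough sketch of that behavior, assuming a local `some_corpus.txt` file (hypothetical path):

```python
from datasets import load_dataset

# streaming=True returns an IterableDataset that reads examples lazily
stream = load_dataset("text", data_files={"train": "some_corpus.txt"}, split="train", streaming=True)

def tokenize_like(examples):
    # stand-in for the real tokenization done in preprocess_dataset
    return {"n_chars": [len(t) for t in examples["text"]]}

stream = stream.map(tokenize_like, batched=True)       # note: no num_proc / cache args in streaming mode
stream = stream.shuffle(buffer_size=16384, seed=42)    # buffer-based shuffling, like data_args.buffer_size
print(next(iter(stream)))                              # iterable datasets cannot be indexed with [0]
```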

View File

@@ -1,13 +1,12 @@
-from typing import Dict
-from datasets import Dataset
+from typing import TYPE_CHECKING, Dict
+
+if TYPE_CHECKING:
+    from datasets import Dataset
 
 
-def split_dataset(
-    dataset: Dataset, dev_ratio: float, do_train: bool
-) -> Dict[str, Dataset]:
-    # Split the dataset
+def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]:
     if do_train:
-        if dev_ratio > 1e-6:
+        if dev_ratio > 1e-6: # Split the dataset
             dataset = dataset.train_test_split(test_size=dev_ratio)
             return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
     else:
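
For reference, `split_dataset` relies on `Dataset.train_test_split`, which returns a dictionary-like object with `train` and `test` splits; a minimal sketch with toy data:

```python
from datasets import Dataset

ds = Dataset.from_dict({"prompt": [f"q{i}" for i in range(10)], "response": ["a"] * 10})
splits = ds.train_test_split(test_size=0.2, seed=42)  # 80/20 split, mirroring dev_ratio
print(len(splits["train"]), len(splits["test"]))      # 8 2
```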

View File

@@ -1,16 +1,13 @@
 import os
 import json
 import time
+from typing import TYPE_CHECKING
 from datetime import timedelta
-from transformers import (
-    TrainerCallback,
-    TrainerControl,
-    TrainerState,
-    TrainingArguments
-)
-from transformers.trainer_callback import TrainerControl, TrainerState
-from transformers.training_args import TrainingArguments
+from transformers import TrainerCallback
+
+if TYPE_CHECKING:
+    from transformers import TrainingArguments, TrainerState, TrainerControl
 
 
 class LogCallback(TrainerCallback):
@@ -20,13 +17,13 @@ class LogCallback(TrainerCallback):
         self.start_time = time.time()
         self.tracker = {}
 
-    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the beginning of training.
         """
         self.start_time = time.time()
 
-    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+    def on_step_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the beginning of a training step. If using gradient accumulation, one training step
         might take several inputs.
@@ -35,7 +32,7 @@ class LogCallback(TrainerCallback):
             control.should_epoch_stop = True
             control.should_training_stop = True
 
-    def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of an substep during gradient accumulation.
         """
@@ -43,7 +40,7 @@ class LogCallback(TrainerCallback):
             control.should_epoch_stop = True
             control.should_training_stop = True
 
-    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
+    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
         r"""
         Event called after logging the last logs.
         """

View File

@@ -1,12 +1,14 @@
 import torch
-from typing import List, Optional
-from transformers.modeling_utils import PreTrainedModel
+from typing import TYPE_CHECKING, List, Optional, Tuple
 from transformers.generation.utils import LogitsProcessorList
 from transformers.generation.logits_process import LogitsProcessor
 
 from llmtuner.extras.constants import LAYERNORM_NAMES
 
+if TYPE_CHECKING:
+    from transformers.modeling_utils import PreTrainedModel
+
 
 class AverageMeter:
     r"""
@@ -44,29 +46,37 @@ def get_logits_processor() -> LogitsProcessorList:
     return logits_processor
 
 
-def print_trainable_params(model: torch.nn.Module) -> None:
+def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
+    r"""
+    Returns the number of trainable parameters and number of all parameters in the model.
+    """
     trainable_params, all_param = 0, 0
     for param in model.parameters():
         num_params = param.numel()
         # if using DS Zero 3 and the weights are initialized empty
         if num_params == 0 and hasattr(param, "ds_numel"):
             num_params = param.ds_numel
+
+        # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
+        if param.__class__.__name__ == "Params4bit":
+            num_params = num_params * 2
+
         all_param += num_params
         if param.requires_grad:
             trainable_params += num_params
-    print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
-        trainable_params, all_param, 100 * trainable_params / all_param))
+
+    return trainable_params, all_param
 
 
 # Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
 # Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
 def prepare_model_for_training(
-    model: PreTrainedModel,
+    model: "PreTrainedModel",
     finetuning_type: str,
     output_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
     layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
-) -> PreTrainedModel:
+) -> "PreTrainedModel":
 
     for name, param in model.named_parameters():
         if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
@@ -84,6 +94,9 @@ def prepare_model_for_training(
         model.config.use_cache = False # turn off when gradient checkpointing is enabled
 
     if finetuning_type != "full" and hasattr(model, output_layer_name):
+        if hasattr(model, "config") and hasattr(model.config, "pretraining_tp"):
+            model.config.pretraining_tp = 1 # disable TP for LoRA (https://github.com/huggingface/peft/pull/728)
+
         output_layer: torch.nn.Linear = getattr(model, output_layer_name)
         input_dtype = output_layer.weight.dtype
@@ -92,11 +105,8 @@ def prepare_model_for_training(
             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return super().forward(x.to(input_dtype)).to(torch.float32)
 
-        new_output_layer = CastOutputToFloat(output_layer)
-        # adapt to LLaMA-2's pretraining_tp (actually LLaMA models can automatically do casting but BLOOM models cannot)
-        # (https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/llama/modeling_llama.py#L819)
-        setattr(new_output_layer, "weight", output_layer.weight)
-        setattr(model, output_layer_name, new_output_layer)
+        setattr(model, output_layer_name, CastOutputToFloat(output_layer))
 
     return model
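
The new `count_parameters` helper returns the counts instead of printing them, so callers can log them however they like. A stripped-down sketch of the same counting logic on a toy module (omitting the DeepSpeed ZeRO-3 `ds_numel` and bitsandbytes `Params4bit` special cases handled above):

```python
import torch

def count_parameters(model: torch.nn.Module):
    trainable, total = 0, 0
    for param in model.parameters():
        n = param.numel()
        total += n
        if param.requires_grad:
            trainable += n
    return trainable, total

model = torch.nn.Linear(16, 4)
model.bias.requires_grad_(False)  # freeze the bias so the two counts differ
trainable, total = count_parameters(model)
print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
    trainable, total, 100 * trainable / total
))
```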

View File

@@ -1,6 +1,6 @@
 import os
 import torch
-from typing import Dict, Optional
+from typing import Dict
 from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
 from transformers.modeling_utils import load_sharded_checkpoint
@@ -12,12 +12,12 @@ from llmtuner.extras.logging import get_logger
 logger = get_logger(__name__)
 
 
-def get_state_dict(model: torch.nn.Module, trainable_only: Optional[bool] = True) -> Dict[str, torch.Tensor]:
-    state_dict = model.state_dict()
+def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor] = model.state_dict()
     filtered_state_dict = {}
 
     for k, v in model.named_parameters():
-        if (not trainable_only) or v.requires_grad:
+        if v.requires_grad:
             filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
 
     return filtered_state_dict

View File

@@ -11,37 +11,46 @@ class Template:
     use_history: bool
 
     def get_prompt(
-        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
+        self,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        prefix: Optional[str] = "",
+        eos_token: Optional[str] = "</s>"
     ) -> str:
         r"""
         Returns a string containing prompt without response.
         """
-        return "".join(self._format_example(query, history, prefix))
+        return eos_token.join(map(lambda x: x[0] + x[1], self._format_example(query, history, prefix)))
 
     def get_dialog(
-        self, query: str, resp: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
-    ) -> List[str]:
+        self,
+        query: str,
+        resp: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        prefix: Optional[str] = ""
+    ) -> List[Tuple[str, str]]:
         r"""
-        Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response.
+        Returns a list containing prompt-response pairs.
         """
-        return self._format_example(query, history, prefix) + [resp]
+        result = self._format_example(query, history, prefix)
+        result[-1][-1] = resp
+        return result
 
     def _format_example(
-        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
-    ) -> List[str]:
+        self,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        prefix: Optional[str] = ""
+    ) -> List[Tuple[str, str]]:
         prefix = prefix or self.prefix # use prefix if provided
         prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
         history = history if (history and self.use_history) else []
-        history = history + [(query, "<dummy>")]
-        convs = []
-        for turn_idx, (user_query, bot_resp) in enumerate(history):
-            if turn_idx == 0:
-                convs.append(prefix + self.prompt.format(query=user_query))
-                convs.append(bot_resp)
-            else:
-                convs.append(self.sep + self.prompt.format(query=user_query))
-                convs.append(bot_resp)
-        return convs[:-1] # drop last
+        history = history + [(query, "")]
+        convs = [
+            [(self.sep if turn_idx else prefix) + self.prompt.format(query=query_i), resp_i]
+            for turn_idx, (query_i, resp_i) in enumerate(history)
+        ]
+        return convs
 
 
 templates: Dict[str, Template] = {}
@@ -103,7 +112,7 @@ register_template(
            "explain why instead of answering something not correct. "
            "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
     prompt=" [INST] {query} [/INST] ",
-    sep="</s>",
+    sep="",
     use_history=True
 )
@@ -131,7 +140,7 @@ register_template(
     prefix="A chat between a curious user and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
     prompt="USER: {query} ASSISTANT: ",
-    sep="</s>",
+    sep="",
     use_history=True
 )
@@ -216,7 +225,7 @@ register_template(
     name="baichuan",
     prefix="",
     prompt="<reserved_102>{query}<reserved_103>",
-    sep="</s>",
+    sep="",
     use_history=True
 )
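
The reworked template returns (prompt, response) pairs from `_format_example`; `get_dialog` fills in the final response and `get_prompt` joins completed turns with the EOS token, which is why the per-template `sep="</s>"` values could be dropped. A toy re-implementation of that behavior (the prefix and prompt strings here are invented, not one of the registered templates):

```python
from typing import List

class ToyTemplate:
    prefix = "You are a helpful assistant."
    prompt = "Human: {query}\nAssistant: "
    sep = "\n"
    use_history = True

    def _format_example(self, query, history=None, prefix="") -> List[List[str]]:
        prefix = prefix or self.prefix
        prefix = prefix + self.sep if prefix else ""
        history = history if (history and self.use_history) else []
        history = history + [(query, "")]  # current turn has an empty response
        return [
            [(self.sep if turn_idx else prefix) + self.prompt.format(query=q), r]
            for turn_idx, (q, r) in enumerate(history)
        ]

    def get_prompt(self, query, history=None, prefix="", eos_token="</s>") -> str:
        # completed turns are separated by the tokenizer's EOS token
        return eos_token.join(p + r for p, r in self._format_example(query, history, prefix))

    def get_dialog(self, query, resp, history=None, prefix="") -> List[List[str]]:
        result = self._format_example(query, history, prefix)
        result[-1][-1] = resp  # fill the response of the last turn
        return result

tmpl = ToyTemplate()
print(tmpl.get_prompt("How are you?", history=[("Hi", "Hello!")]))
```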

View File

@@ -1,6 +1,6 @@
 import os
 import json
-from typing import List, Optional
+from typing import List, Literal, Optional
 from dataclasses import dataclass, field
@@ -16,10 +16,10 @@ class DatasetAttr:
         return self.dataset_name
 
     def __post_init__(self):
-        self.prompt_column = "instruction"
-        self.query_column = "input"
-        self.response_column = "output"
-        self.history_column = None
+        self.prompt = "instruction"
+        self.query = "input"
+        self.response = "output"
+        self.history = None
 
 
 @dataclass
@@ -27,8 +27,11 @@ class DataArguments:
     """
     Arguments pertaining to what data we are going to input our model for training and evaluation.
     """
+    template: str = field(
+        metadata={"help": "Which template to use for constructing prompts in training and inference."}
+    )
     dataset: Optional[str] = field(
-        default="alpaca_zh",
+        default="alpaca_en",
         metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
     )
     dataset_dir: Optional[str] = field(
@@ -39,6 +42,18 @@ class DataArguments:
         default="train",
         metadata={"help": "Which dataset split to use for training and evaluation."}
     )
+    streaming: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Enable streaming mode."}
+    )
+    buffer_size: Optional[int] = field(
+        default=16384,
+        metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
+    )
+    mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
+        default="concat",
+        metadata={"help": "Strategy to use in dataset mixing."}
+    )
     overwrite_cache: Optional[bool] = field(
         default=False,
         metadata={"help": "Overwrite the cached training and evaluation sets."}
@@ -75,10 +90,6 @@ class DataArguments:
         default=0,
         metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
     )
-    prompt_template: Optional[str] = field(
-        default="default",
-        metadata={"help": "Which template to use for constructing prompts in training and inference."}
-    )
 
     def init_for_training(self): # support mixing multiple datasets
         dataset_names = [ds.strip() for ds in self.dataset.split(",")]
@@ -111,9 +122,9 @@ class DataArguments:
                 dataset_attr.source_prefix = prefix_list[i]
 
             if "columns" in dataset_info[name]:
-                dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
-                dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
-                dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
-                dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
+                dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
+                dataset_attr.query = dataset_info[name]["columns"].get("query", None)
+                dataset_attr.response = dataset_info[name]["columns"].get("response", None)
+                dataset_attr.history = dataset_info[name]["columns"].get("history", None)
 
             self.dataset_list.append(dataset_attr)
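
With the column attributes renamed to `prompt`/`query`/`response`/`history`, an entry in `dataset_info.json` can remap custom column names onto the canonical ones. The sketch below is hypothetical: only the `columns` mapping follows the hunk above, while the dataset name and the other keys are assumptions for illustration.

```python
import json

# assumed shape of a custom dataset_info.json entry; key names other than "columns" are illustrative
dataset_info = {
    "my_custom_data": {
        "columns": {
            "prompt": "instruction",   # column holding the main instruction
            "query": "input",          # optional extra input appended to the prompt
            "response": "output",      # column holding the answer
            "history": "history"       # optional list of previous (query, response) turns
        }
    }
}
print(json.dumps(dataset_info, indent=2))
```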

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
 @dataclass
 class GeneralArguments:
     """
-    Arguments pertaining to which techniques we are going to fine-tuning with.
+    Arguments pertaining to which stage we are going to perform.
     """
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
         default="sft",

View File

@@ -1,7 +1,7 @@
 import os
 import torch
+from typing import TYPE_CHECKING
-from transformers.modeling_utils import PreTrainedModel
 from peft import (
     PeftModel,
     TaskType,
@@ -12,19 +12,22 @@ from peft.utils import CONFIG_NAME, WEIGHTS_NAME
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.save_and_load import load_trainable_params
-from llmtuner.hparams import ModelArguments, FinetuningArguments
+
+if TYPE_CHECKING:
+    from transformers.modeling_utils import PreTrainedModel
+    from llmtuner.hparams import ModelArguments, FinetuningArguments
 
 
 logger = get_logger(__name__)
 
 
 def init_adapter(
-    model: PreTrainedModel,
-    model_args: ModelArguments,
-    finetuning_args: FinetuningArguments,
+    model: "PreTrainedModel",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
     is_trainable: bool,
     is_mergeable: bool
-) -> PreTrainedModel:
+) -> "PreTrainedModel":
     r"""
     Initializes the adapters.

View File

@@ -1,6 +1,6 @@
 import os
 import torch
-from typing import Literal, Optional, Tuple
+from typing import TYPE_CHECKING, Literal, Optional, Tuple
 from transformers import (
     AutoConfig,
@@ -16,11 +16,13 @@ from transformers.tokenization_utils import PreTrainedTokenizerBase
 from trl import AutoModelForCausalLMWithValueHead
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import prepare_model_for_training, print_trainable_params
+from llmtuner.extras.misc import count_parameters, prepare_model_for_training
 from llmtuner.extras.save_and_load import load_valuehead_params
-from llmtuner.hparams import ModelArguments, FinetuningArguments
 from llmtuner.tuner.core.adapter import init_adapter
+if TYPE_CHECKING:
+    from llmtuner.hparams import ModelArguments, FinetuningArguments
 logger = get_logger(__name__)
@@ -33,8 +35,8 @@ require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
 def load_model_and_tokenizer(
-    model_args: ModelArguments,
-    finetuning_args: FinetuningArguments,
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
     is_trainable: Optional[bool] = False,
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
 ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
@@ -141,6 +143,9 @@ def load_model_and_tokenizer(
         model.requires_grad_(False) # fix all model params
         model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
-    print_trainable_params(model)
+    trainable_params, all_param = count_parameters(model)
+    logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
+        trainable_params, all_param, 100 * trainable_params / all_param
+    ))
     return model, tokenizer
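`count_parameters` replaces the old `print_trainable_params` so the caller controls the log format. A sketch of what such a helper typically looks like (adapted from PEFT's `print_trainable_parameters`; the DeepSpeed ZeRO-3 branch is an assumption about sharded parameters, not necessarily the repo's exact implementation):

```python
from typing import Tuple

import torch.nn as nn


def count_parameters(model: nn.Module) -> Tuple[int, int]:
    # Returns (number of trainable parameters, total number of parameters).
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        # Assumption: under DeepSpeed ZeRO-3 a sharded parameter reports
        # numel() == 0 and exposes the full size as ds_numel.
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    return trainable_params, all_param
```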

View File

@@ -19,20 +19,39 @@ from llmtuner.hparams import (
 logger = get_logger(__name__)
+def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None):
+    if args is not None:
+        return parser.parse_dict(args)
+    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
+        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
+    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
+    else:
+        return parser.parse_args_into_dataclasses()
+def parse_train_args(
+    args: Optional[Dict[str, Any]] = None
+) -> Tuple[GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
+    parser = HfArgumentParser((
+        GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments
+    ))
+    return _parse_args(parser, args)
+def parse_infer_args(
+    args: Optional[Dict[str, Any]] = None
+) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
+    parser = HfArgumentParser((
+        ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+    ))
+    return _parse_args(parser, args)
 def get_train_args(
     args: Optional[Dict[str, Any]] = None
 ) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
-    parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments))
-    if args is not None:
-        model_args, data_args, training_args, finetuning_args, general_args = parser.parse_dict(args)
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
-        model_args, data_args, training_args, finetuning_args, general_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        model_args, data_args, training_args, finetuning_args, general_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
-    else:
-        model_args, data_args, training_args, finetuning_args, general_args = parser.parse_args_into_dataclasses()
+    general_args, model_args, data_args, training_args, finetuning_args = parse_train_args(args)
     # Setup logging
     if training_args.should_log:
@@ -73,13 +92,22 @@ def get_train_args(
     if training_args.do_train and (not training_args.fp16):
         logger.warning("We recommend enable fp16 mixed precision training.")
-    if data_args.prompt_template == "default":
-        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
-    if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
-        logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
+    if (
+        training_args.local_rank != -1
+        and training_args.ddp_find_unused_parameters is None
+        and finetuning_args.finetuning_type == "lora"
+    ):
+        logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
         training_args.ddp_find_unused_parameters = False
+    if data_args.max_samples is not None and data_args.streaming:
+        logger.warning("`max_samples` is incompatible with `streaming`. Disabling streaming mode.")
+        data_args.streaming = False
+    if data_args.dev_ratio > 1e-6 and data_args.streaming:
+        logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
+        data_args.dev_ratio = 0
     training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
     if model_args.quantization_bit is not None:
@@ -106,17 +134,7 @@ def get_train_args(
 def get_infer_args(
     args: Optional[Dict[str, Any]] = None
 ) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
-    parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments))
-    if args is not None:
-        model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args)
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
-        model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
-    else:
-        model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
+    model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
     assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
         "Quantization is only compatible with the LoRA method."
@@ -128,7 +146,4 @@ def get_infer_args(
     assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
         "Quantized model only accepts a single checkpoint."
-    if data_args.prompt_template == "default":
-        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
     return model_args, data_args, finetuning_args, generating_args
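All four entry paths (Python dict, YAML file, JSON file, raw CLI flags) now funnel through `_parse_args`, so `get_train_args` and `get_infer_args` shrink to a single call each. A hedged usage sketch; the argument values below are placeholders:

```python
# Programmatic call: a dict goes through parser.parse_dict(...)
general_args, model_args, data_args, training_args, finetuning_args = parse_train_args({
    "stage": "sft",
    "model_name_or_path": "path_to_your_model",  # placeholder
    "do_train": True,
    "dataset": "alpaca_gpt4_en",
    "finetuning_type": "lora",
    "output_dir": "path_to_sft_checkpoint"       # placeholder
})

# File-based call: `python src/train_bash.py config.yaml` takes the
# parse_yaml_file branch, `config.json` takes parse_json_file, and plain
# `--flag value` arguments fall back to parse_args_into_dataclasses().
```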

View File

@@ -1,16 +1,19 @@
 import os
 import torch
-from typing import Dict, Optional
+from typing import TYPE_CHECKING, Dict, Optional
 from transformers import Seq2SeqTrainer
-from transformers.trainer import TRAINING_ARGS_NAME
+from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
 from transformers.modeling_utils import PreTrainedModel, unwrap_model
 from peft import PeftModel
+from trl import PreTrainedModelWrapper
 from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params, load_valuehead_params
-from llmtuner.hparams import FinetuningArguments
+from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
+if TYPE_CHECKING:
+    from llmtuner.hparams import FinetuningArguments
 logger = get_logger(__name__)
@@ -21,7 +24,7 @@ class PeftTrainer(Seq2SeqTrainer):
     Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
     """
-    def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
+    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
         self._remove_log()
@@ -42,31 +45,35 @@ class PeftTrainer(Seq2SeqTrainer):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         os.makedirs(output_dir, exist_ok=True)
         logger.info(f"Saving model checkpoint to {output_dir}")
         model = unwrap_model(self.model)
-        if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
-            backbone_model = getattr(model, "pretrained_model")
-            torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
-        else:
-            backbone_model = model
-        if isinstance(backbone_model, PeftModel): # LoRA tuning
-            backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
-        elif isinstance(backbone_model, PreTrainedModel): # freeze/full tuning
-            backbone_model.config.use_cache = True
-            backbone_model.save_pretrained(
-                output_dir,
-                state_dict=get_state_dict(backbone_model, trainable_only=(self.finetuning_args.finetuning_type != "full")),
-                safe_serialization=self.args.save_safetensors
-            )
-            backbone_model.config.use_cache = False
+        state_dict = state_dict or get_state_dict(model)
+        if isinstance(model, PreTrainedModelWrapper):
+            model_params, v_head_params = {}, {}
+            for name in state_dict.keys():
+                if name.startswith("pretrained_model."):
+                    model_params[name.replace("pretrained_model.", "")] = state_dict[name]
+                elif name.startswith("v_head."):
+                    v_head_params[name.replace("v_head.", "")] = state_dict[name]
+            torch.save(v_head_params, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
+            state_dict = model_params
+            model = model.pretrained_model
+        if isinstance(model, (PeftModel, PreTrainedModel)):
+            model.config.use_cache = True
+            model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)
+            model.config.use_cache = False
+        else:
+            torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
-        else:
-            logger.warning("No model to save.")
         with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
             f.write(self.args.to_json_string() + "\n")
         self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
     def _load_best_model(self):
@@ -76,16 +83,15 @@ class PeftTrainer(Seq2SeqTrainer):
         Subclass and override to inject custom behavior. It should not be directly used by external scripts.
         """
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
         model = unwrap_model(self.model)
-        backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model
-        if isinstance(backbone_model, PeftModel):
-            backbone_model.load_adapter(self.state.best_model_checkpoint, backbone_model.active_adapter)
-            if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
-                model.v_head.load_state_dict({
-                    "summary.weight": getattr(model, "reward_head_weight"),
-                    "summary.bias": getattr(model, "reward_head_bias")
-                })
+        if isinstance(model, PreTrainedModelWrapper):
+            model.v_head.load_state_dict(torch.load(
+                os.path.join(self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME), map_location="cpu"
+            ))
+            model = model.pretrained_model
+        if isinstance(model, PeftModel):
+            model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
         else: # freeze/full-tuning
-            load_trainable_params(backbone_model, self.state.best_model_checkpoint)
+            load_trainable_params(model, self.state.best_model_checkpoint)
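The new `_save` splits a value-head model into two artifacts: the backbone weights (written via `save_pretrained`, or as a raw `WEIGHTS_NAME` file) and the `v_head.*` tensors (written to `VALUE_HEAD_FILE_NAME`). A standalone sketch of that split; `split_valuehead_state_dict` and the dummy tensors are illustrative, not part of the repo:

```python
import torch


def split_valuehead_state_dict(state_dict):
    # Mirror of the saving logic above: strip the wrapper prefix from backbone
    # weights and collect the value-head weights separately.
    model_params, v_head_params = {}, {}
    for name, tensor in state_dict.items():
        if name.startswith("pretrained_model."):
            model_params[name.replace("pretrained_model.", "")] = tensor
        elif name.startswith("v_head."):
            v_head_params[name.replace("v_head.", "")] = tensor
    return model_params, v_head_params


# Dummy state dict shaped like AutoModelForCausalLMWithValueHead's output.
dummy = {
    "pretrained_model.model.embed_tokens.weight": torch.zeros(4, 2),
    "v_head.summary.weight": torch.zeros(1, 2),
    "v_head.summary.bias": torch.zeros(1),
}
backbone, v_head = split_valuehead_state_dict(dummy)
assert "model.embed_tokens.weight" in backbone and "summary.weight" in v_head
```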

View File

@@ -2,21 +2,25 @@ import os
 import math
 import torch
 from tqdm import tqdm
-from typing import Callable, Dict, List, Optional
-from transformers import Seq2SeqTrainingArguments, TrainerState, TrainerControl
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from transformers import TrainerState, TrainerControl
 from transformers.modeling_utils import PreTrainedModel
 from trl import PPOTrainer
 from trl.core import LengthSampler
-from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, get_logits_processor
-from llmtuner.hparams import FinetuningArguments
 from llmtuner.tuner.core.trainer import PeftTrainer
 from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
+if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments
+    from llmtuner.extras.callbacks import LogCallback
+    from llmtuner.hparams import FinetuningArguments
 logger = get_logger(__name__)
@@ -27,9 +31,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
     """
     def __init__(
         self,
-        training_args: Seq2SeqTrainingArguments,
-        finetuning_args: FinetuningArguments,
-        callbacks: List[LogCallback],
+        training_args: "Seq2SeqTrainingArguments",
+        finetuning_args: "FinetuningArguments",
+        callbacks: List["LogCallback"],
         **kwargs
     ):
         PPOTrainer.__init__(self, **kwargs)

View File

@@ -1,11 +1,13 @@
 import torch
-from typing import Dict, List, Literal, Optional, Tuple
-from trl import AutoModelForCausalLMWithValueHead
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
 from llmtuner.extras.constants import LAYERNORM_NAMES
+if TYPE_CHECKING:
+    from trl import AutoModelForCausalLMWithValueHead
-def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
+def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
     if target == "reward": # save default head temporarily
         valuehead_state_dict = model.v_head.state_dict()
         setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
@@ -19,10 +21,10 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def
 def cast_layernorm_dtype(
-    model: AutoModelForCausalLMWithValueHead,
+    model: "AutoModelForCausalLMWithValueHead",
     layer_norm_names: List[str] = LAYERNORM_NAMES,
     layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
-) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
+) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]:
     layer_norm_state_dict = {}
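`replace_model` swaps the reward head in and out of the shared value-head model during PPO. A hedged sketch of the intended call pattern; everything except `replace_model` is an illustrative name, and the real trainer batches and prepares its inputs differently:

```python
def score_responses(model, tokenizer, queries, responses):
    replace_model(model, target="reward")   # temporarily install the reward head
    inputs = tokenizer(
        [q + r for q, r in zip(queries, responses)],
        return_tensors="pt", padding=True
    )
    _, _, values = model(**inputs)          # value-head output per token
    rewards = values[:, -1]                 # take the score of the final token
    replace_model(model, target="default")  # restore the original value head
    return rewards
```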

View File

@@ -2,26 +2,30 @@
 # https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
 import math
+from typing import TYPE_CHECKING
 from trl import PPOConfig
 from torch.optim import AdamW
 from typing import Optional, List
-from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, TrainerCallback
+from transformers import DataCollatorForSeq2Seq
 from transformers.optimization import get_scheduler
 from llmtuner.dsets import get_dataset, preprocess_dataset
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
+if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 def run_ppo(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: Seq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")

View File

@@ -1,24 +1,27 @@
 # Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
 import math
-from typing import Optional, List
-from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
+from typing import TYPE_CHECKING, Optional, List
+from transformers import DataCollatorForSeq2Seq
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.core.trainer import PeftTrainer
+if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 def run_pt(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: Seq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")

View File

@@ -15,5 +15,8 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
         We generate 2 * n examples where the first n examples represent chosen examples and
         the last n examples represent rejected examples.
         """
-        features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
+        features = [
+            {"input_ids": feature[key], "attention_mask": [1] * len(feature[key])}
+            for key in ("accept_ids", "reject_ids") for feature in features
+        ]
         return super().__call__(features)
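The rewritten comprehension now emits an explicit attention mask next to each sequence, so `DataCollatorWithPadding` pads both fields together. A small sketch of its effect on two pairwise examples (the token ids are made up):

```python
features = [
    {"accept_ids": [1, 2, 3], "reject_ids": [1, 2]},
    {"accept_ids": [4, 5], "reject_ids": [4, 5, 6]},
]

rows = [
    {"input_ids": feature[key], "attention_mask": [1] * len(feature[key])}
    for key in ("accept_ids", "reject_ids") for feature in features
]
# rows[0], rows[1] are the chosen sequences; rows[2], rows[3] are the rejected ones,
# e.g. rows[0] == {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]}.
# DataCollatorWithPadding.__call__ then pads input_ids and attention_mask in lockstep.
```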

View File

@@ -1,13 +1,15 @@
 import os
 import json
 import torch
-from typing import Dict, List, Optional, Tuple, Union
-from transformers.trainer import PredictionOutput
-from transformers.modeling_utils import PreTrainedModel
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 from llmtuner.extras.logging import get_logger
 from llmtuner.tuner.core.trainer import PeftTrainer
+if TYPE_CHECKING:
+    from transformers.trainer import PredictionOutput
+    from transformers.modeling_utils import PreTrainedModel
 logger = get_logger(__name__)
@@ -23,7 +25,7 @@ class PairwisePeftTrainer(PeftTrainer):
     def compute_loss(
         self,
-        model: PreTrainedModel,
+        model: "PreTrainedModel",
         inputs: Dict[str, torch.Tensor],
         return_outputs: Optional[bool] = False
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
@@ -46,7 +48,7 @@ class PairwisePeftTrainer(PeftTrainer):
     def save_predictions(
         self,
-        predict_results: PredictionOutput
+        predict_results: "PredictionOutput"
     ) -> None:
         r"""
         Saves model predictions to `output_dir`.

View File

@@ -2,25 +2,27 @@
 # https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
 # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
-from typing import Optional, List
-from transformers import Seq2SeqTrainingArguments, TrainerCallback
+from typing import TYPE_CHECKING, Optional, List
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.rm.metric import compute_accuracy
 from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
 from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
+if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 def run_rm(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: Seq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")

View File

@@ -1,7 +1,6 @@
 import numpy as np
 from dataclasses import dataclass
-from typing import Dict, Sequence, Tuple, Union
-from transformers.tokenization_utils import PreTrainedTokenizer
+from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
 import jieba
 from rouge_chinese import Rouge
@@ -9,6 +8,9 @@ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
 from llmtuner.extras.constants import IGNORE_INDEX
+if TYPE_CHECKING:
+    from transformers.tokenization_utils import PreTrainedTokenizer
 @dataclass
 class ComputeMetrics:
@@ -16,7 +18,7 @@ class ComputeMetrics:
     Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
     """
-    tokenizer: PreTrainedTokenizer
+    tokenizer: "PreTrainedTokenizer"
     def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
         r"""

View File

@@ -3,13 +3,15 @@ import json
 import torch
 import numpy as np
 import torch.nn as nn
-from typing import Any, Dict, List, Optional, Tuple, Union
-from transformers.trainer import PredictionOutput
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.logging import get_logger
 from llmtuner.tuner.core.trainer import PeftTrainer
+if TYPE_CHECKING:
+    from transformers.trainer import PredictionOutput
 logger = get_logger(__name__)
@@ -81,7 +83,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
     def save_predictions(
         self,
-        predict_results: PredictionOutput
+        predict_results: "PredictionOutput"
     ) -> None:
         r"""
         Saves model predictions to `output_dir`.

View File

@@ -1,25 +1,28 @@
 # Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
-from typing import Optional, List
-from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
+from typing import TYPE_CHECKING, Optional, List
+from transformers import DataCollatorForSeq2Seq
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.sft.metric import ComputeMetrics
 from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
+if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 def run_sft(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: Seq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
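With every workflow taking the same five-argument shape, the `stage` field from `GeneralArguments` is what selects among them. A hedged dispatch sketch; `run_experiment` is an illustrative name, and the repo's actual entry point may wire callbacks and logging differently:

```python
def run_experiment(args=None):
    general_args, model_args, data_args, training_args, finetuning_args = parse_train_args(args)
    runners = {"pt": run_pt, "sft": run_sft, "rm": run_rm, "ppo": run_ppo}
    # Select the workflow named by --stage (pt, sft, rm or ppo).
    runners[general_args.stage](model_args, data_args, training_args, finetuning_args)
```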

View File

@@ -54,7 +54,7 @@ class WebChatModel(ChatModel):
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
             quantization_bit=int(quantization_bit) if quantization_bit else None,
-            prompt_template=template,
+            template=template,
             source_prefix=source_prefix
         )
         super().__init__(*get_infer_args(args))
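After the rename, callers that used to pass `prompt_template` now pass `template`, matching the `--template` flag. A minimal sketch of the equivalent programmatic call; the model, template and checkpoint values are placeholders:

```python
model_args, data_args, finetuning_args, generating_args = get_infer_args(dict(
    model_name_or_path="meta-llama/Llama-2-7b-hf",  # placeholder model
    template="llama2",                              # formerly `prompt_template`
    finetuning_type="lora",
    checkpoint_dir="path_to_checkpoint"             # placeholder path
))
```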

View File

@@ -111,7 +111,7 @@ class Runner:
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
             quantization_bit=int(quantization_bit) if quantization_bit else None,
-            prompt_template=template,
+            template=template,
             source_prefix=source_prefix,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),
@@ -201,7 +201,7 @@ class Runner:
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
             quantization_bit=int(quantization_bit) if quantization_bit else None,
-            prompt_template=template,
+            template=template,
             source_prefix=source_prefix,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),