Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Synced 2025-11-05 18:32:14 +08:00
Former-commit-id: 819cc1353599e5fa45658bc56dd0dbe4b258b197
This commit is contained in:
parent 5289530dcf
commit 63123a9098

README.md (16 lines changed)
@@ -12,15 +12,15 @@
 
 ## Changelog
 
-[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--prompt_template llama2` argument when you are using the LLaMA-2-chat model.
+[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--template llama2` argument when you are using the LLaMA-2-chat model.
 
 [23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
 
-[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--prompt_template baichuan` argument when you are using the Baichuan-13B-Chat model.
+[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--template baichuan` argument when you are using the Baichuan-13B-Chat model.
 
 [23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
 
-[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--prompt_template intern` argument when you are using the InternLM-chat model.
+[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--template intern` argument when you are using the InternLM-chat model.
 
 [23/07/05] Now we support training the **Falcon-7B/40B** models in this repo. Try `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments to use the Falcon model.
 

@@ -153,6 +153,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --model_name_or_path path_to_your_model \
     --do_train \
     --dataset wiki_demo \
+    --template default \
     --finetuning_type lora \
     --output_dir path_to_pt_checkpoint \
     --overwrite_cache \

@@ -175,6 +176,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --model_name_or_path path_to_your_model \
     --do_train \
     --dataset alpaca_gpt4_en \
+    --template default \
     --finetuning_type lora \
     --output_dir path_to_sft_checkpoint \
     --overwrite_cache \

@@ -197,6 +199,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --model_name_or_path path_to_your_model \
     --do_train \
     --dataset comparison_gpt4_en \
+    --template default \
     --finetuning_type lora \
     --resume_lora_training False \
     --checkpoint_dir path_to_sft_checkpoint \

@@ -220,6 +223,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --model_name_or_path path_to_your_model \
     --do_train \
     --dataset alpaca_gpt4_en \
+    --template default \
     --finetuning_type lora \
     --resume_lora_training False \
     --checkpoint_dir path_to_sft_checkpoint \

@@ -278,6 +282,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --model_name_or_path path_to_your_model \
     --do_eval \
     --dataset alpaca_gpt4_en \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
     --output_dir path_to_eval_result \

@@ -296,6 +301,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --model_name_or_path path_to_your_model \
     --do_predict \
     --dataset alpaca_gpt4_en \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
     --output_dir path_to_predict_result \

@@ -311,6 +317,7 @@ If you want to predict the samples with empty responses, please kindly fill the
 ```bash
 python src/api_demo.py \
     --model_name_or_path path_to_your_model \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint
 ```

@@ -322,6 +329,7 @@ Visit `http://localhost:8000/docs` for API documentation.
 ```bash
 python src/cli_demo.py \
     --model_name_or_path path_to_your_model \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint
 ```

@@ -331,6 +339,7 @@ python src/cli_demo.py \
 ```bash
 python src/web_demo.py \
     --model_name_or_path path_to_your_model \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint
 ```

@@ -340,6 +349,7 @@ python src/web_demo.py \
 ```bash
 python src/export_model.py \
     --model_name_or_path path_to_your_model \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
     --output_dir path_to_export
26
README_zh.md
26
README_zh.md
@@ -12,15 +12,15 @@
 
 ## 更新日志
 
-[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。请注意使用 LLaMA-2-chat 模型需要添加 `--prompt_template llama2` 参数。
+[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。请注意使用 LLaMA-2-chat 模型需要添加 `--template llama2` 参数。
 
 [23/07/18] 我们开发了支持训练和测试的浏览器一键微调界面。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
 
-[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path path_to_baichuan_model` 和 `--lora_target W_pack` 参数。请注意使用 Baichuan-13B-Chat 模型需要添加 `--prompt_template baichuan` 参数。
+[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path path_to_baichuan_model` 和 `--lora_target W_pack` 参数。请注意使用 Baichuan-13B-Chat 模型需要添加 `--template baichuan` 参数。
 
 [23/07/09] 我们开源了 [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。
 
-[23/07/07] 现在我们支持了 **InternLM-7B** 模型的训练。请尝试使用 `--model_name_or_path internlm/internlm-7b` 参数。请注意使用 InternLM-chat 模型需要添加 `--prompt_template intern` 参数。
+[23/07/07] 现在我们支持了 **InternLM-7B** 模型的训练。请尝试使用 `--model_name_or_path internlm/internlm-7b` 参数。请注意使用 InternLM-chat 模型需要添加 `--template intern` 参数。
 
 [23/07/05] 现在我们支持了 **Falcon-7B/40B** 模型的训练。请尝试使用 `--model_name_or_path tiiuae/falcon-7b` 和 `--lora_target query_key_value` 参数。
 

@@ -153,6 +153,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --model_name_or_path path_to_your_model \
     --do_train \
     --dataset wiki_demo \
+    --template default \
     --finetuning_type lora \
     --output_dir path_to_pt_checkpoint \
     --overwrite_cache \

@@ -174,7 +175,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage sft \
     --model_name_or_path path_to_your_model \
     --do_train \
-    --dataset alpaca_gpt4_en \
+    --dataset alpaca_gpt4_zh \
+    --template default \
     --finetuning_type lora \
     --output_dir path_to_sft_checkpoint \
     --overwrite_cache \

@@ -196,7 +198,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage rm \
     --model_name_or_path path_to_your_model \
     --do_train \
-    --dataset comparison_gpt4_en \
+    --dataset comparison_gpt4_zh \
+    --template default \
     --finetuning_type lora \
     --resume_lora_training False \
     --checkpoint_dir path_to_sft_checkpoint \

@@ -219,7 +222,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage ppo \
     --model_name_or_path path_to_your_model \
     --do_train \
-    --dataset alpaca_gpt4_en \
+    --dataset alpaca_gpt4_zh \
+    --template default \
     --finetuning_type lora \
     --resume_lora_training False \
     --checkpoint_dir path_to_sft_checkpoint \

@@ -277,7 +281,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage sft \
     --model_name_or_path path_to_your_model \
     --do_eval \
-    --dataset alpaca_gpt4_en \
+    --dataset alpaca_gpt4_zh \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
     --output_dir path_to_eval_result \

@@ -295,7 +300,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage sft \
     --model_name_or_path path_to_your_model \
     --do_predict \
-    --dataset alpaca_gpt4_en \
+    --dataset alpaca_gpt4_zh \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
     --output_dir path_to_predict_result \

@@ -311,6 +317,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 ```bash
 python src/api_demo.py \
     --model_name_or_path path_to_your_model \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint
 ```

@@ -322,6 +329,7 @@ python src/api_demo.py \
 ```bash
 python src/cli_demo.py \
     --model_name_or_path path_to_your_model \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint
 ```

@@ -331,6 +339,7 @@ python src/cli_demo.py \
 ```bash
 python src/web_demo.py \
     --model_name_or_path path_to_your_model \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint
 ```

@@ -340,6 +349,7 @@ python src/web_demo.py \
 ```bash
 python src/export_model.py \
     --model_name_or_path path_to_your_model \
+    --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
     --output_dir path_to_export
@@ -1,42 +1,50 @@
 import torch
-from typing import Any, Dict, Generator, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer
 
 from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.template import get_template
-from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
 from llmtuner.tuner import load_model_and_tokenizer
 
+if TYPE_CHECKING:
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
 
 
 class ChatModel:
 
     def __init__(
         self,
-        model_args: ModelArguments,
-        data_args: DataArguments,
-        finetuning_args: FinetuningArguments,
-        generating_args: GeneratingArguments
+        model_args: "ModelArguments",
+        data_args: "DataArguments",
+        finetuning_args: "FinetuningArguments",
+        generating_args: "GeneratingArguments"
     ) -> None:
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
 
         if torch.cuda.device_count() > 1:
-            from accelerate import dispatch_model, infer_auto_device_map
-            device_map = infer_auto_device_map(self.model)
+            from accelerate import dispatch_model
+            from accelerate.utils import infer_auto_device_map, get_balanced_memory
+            device_map = infer_auto_device_map(self.model, max_memory=get_balanced_memory(self.model))
             self.model = dispatch_model(self.model, device_map)
         else:
            self.model = self.model.cuda()
 
-        self.template = get_template(data_args.prompt_template)
-        self.source_prefix = data_args.source_prefix or ""
+        self.template = get_template(data_args.template)
+        self.source_prefix = data_args.source_prefix
         self.generating_args = generating_args
 
     def process_args(
-        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
+        self,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        prefix: Optional[str] = None,
+        **input_kwargs
    ) -> Tuple[Dict[str, Any], int]:
         prefix = prefix or self.source_prefix
 
-        inputs = self.tokenizer([self.template.get_prompt(query, history, prefix)], return_tensors="pt")
+        prompt = self.template.get_prompt(query, history, prefix, self.tokenizer.eos_token)
+        inputs = self.tokenizer([prompt], return_tensors="pt")
         inputs = inputs.to(self.model.device)
         prompt_length = len(inputs["input_ids"][0])
 
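Note on the hunk above: the bare `infer_auto_device_map` call is replaced by one balanced with `get_balanced_memory`, so layers are spread across GPUs rather than piling onto the first device. A minimal, hedged sketch of that dispatch pattern outside the repo; the checkpoint name below is only an illustrative placeholder.

```python
# Hedged sketch of the multi-GPU dispatch pattern used above; not the repo's code path.
import torch
from transformers import AutoModelForCausalLM
from accelerate import dispatch_model
from accelerate.utils import get_balanced_memory, infer_auto_device_map

model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype=torch.float16)

if torch.cuda.device_count() > 1:
    # get_balanced_memory computes a per-GPU budget so layers are spread evenly
    max_memory = get_balanced_memory(model)
    device_map = infer_auto_device_map(model, max_memory=max_memory)
    model = dispatch_model(model, device_map=device_map)
else:
    model = model.cuda()
```

Without the `max_memory` budget, the auto device map tends to fill the first GPU before touching the rest, which is what the old call did.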
@@ -71,7 +79,11 @@ class ChatModel:
 
     @torch.inference_mode()
     def chat(
-        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
+        self,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        prefix: Optional[str] = None,
+        **input_kwargs
     ) -> Tuple[str, Tuple[int, int]]:
         gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
         generation_output = self.model.generate(**gen_kwargs)

@@ -82,7 +94,11 @@ class ChatModel:
 
     @torch.inference_mode()
     def stream_chat(
-        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
+        self,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        prefix: Optional[str] = None,
+        **input_kwargs
     ) -> Generator[str, None, None]:
         gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
@@ -1,40 +1,50 @@
 import os
 import hashlib
-from typing import List
+from typing import TYPE_CHECKING, List, Optional
 
-from datasets import Dataset, concatenate_datasets, load_dataset
+from datasets import concatenate_datasets, interleave_datasets, load_dataset
 
 from llmtuner.extras.logging import get_logger
-from llmtuner.hparams import ModelArguments, DataArguments
+
+if TYPE_CHECKING:
+    from datasets import Dataset
+    from llmtuner.hparams import ModelArguments, DataArguments
 
 
 logger = get_logger(__name__)
 
 
+EXT2TYPE = {
+    "csv": "csv",
+    "json": "json",
+    "jsonl": "json",
+    "txt": "text"
+}
 
 
+def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
+    if file_sha1 is None:
+        logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
+        return
+
+    if len(data_files) != 1:
+        logger.warning("Checksum failed: too many files.")
+        return
+
+    with open(data_files[0], "rb") as f:
+        sha1 = hashlib.sha1(f.read()).hexdigest()
+        if sha1 != file_sha1:
+            logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
 
 
 def get_dataset(
-    model_args: ModelArguments,
-    data_args: DataArguments
-) -> Dataset:
+    model_args: "ModelArguments",
+    data_args: "DataArguments"
+) -> "Dataset":
 
-    def checksum(file_path, hash):
-        with open(file_path, "rb") as datafile:
-            binary_data = datafile.read()
-        sha1 = hashlib.sha1(binary_data).hexdigest()
-        if sha1 != hash:
-            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
 
-    ext2type = {
-        "csv": "csv",
-        "json": "json",
-        "jsonl": "json",
-        "txt": "text"
-    }
 
     max_samples = data_args.max_samples
-    all_datasets: List[Dataset] = [] # support multiple datasets
+    all_datasets: List["Dataset"] = [] # support multiple datasets
 
     for dataset_attr in data_args.dataset_list:
 
         logger.info("Loading dataset {}...".format(dataset_attr))
 
         if dataset_attr.load_from == "hf_hub":
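The new module-level `checksum` helper above compares a single local dataset file against the `file_sha1` recorded in `dataset_info.json` and only logs a warning on mismatch. A hedged sketch of producing such a hash for your own file; the path is a hypothetical placeholder, not a file shipped with the repo.

```python
# Hedged sketch: compute the SHA-1 value that the checksum helper above expects.
import hashlib

with open("data/my_dataset.json", "rb") as f:  # hypothetical file path
    print(hashlib.sha1(f.read()).hexdigest())
```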
@@ -47,60 +57,56 @@ def get_dataset(
             data_path = None
             data_files: List[str] = []
 
-            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
+            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # directory
                 for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                     data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
 
                     if data_path is None:
-                        data_path = ext2type.get(data_files[0].split(".")[-1], None)
+                        data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
                     else:
-                        assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
-            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
+                        assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file type does not match."
+            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # single file
                 data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
-                data_path = ext2type.get(data_files[0].split(".")[-1], None)
+                data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
             else:
                 raise ValueError("File not found.")
 
             assert data_path, "File extension must be txt, csv, json or jsonl."
-            if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
-                checksum(data_files[0], dataset_attr.dataset_sha1)
-            else:
-                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
+            checksum(data_files, dataset_attr.dataset_sha1)
         else:
             raise NotImplementedError
 
-        raw_datasets = load_dataset(
+        dataset = load_dataset(
             data_path,
             data_files=data_files,
+            split=data_args.split,
             cache_dir=model_args.cache_dir,
+            streaming=data_args.streaming,
             use_auth_token=True if model_args.use_auth_token else None
         )
-        dataset = raw_datasets[data_args.split]
 
         if max_samples is not None:
             max_samples_temp = min(len(dataset), max_samples)
             dataset = dataset.select(range(max_samples_temp))
 
-        dummy_data = [None] * len(dataset)
-        prefix_data = [dataset_attr.source_prefix] * len(dataset)
-        for column_name, target_name in [
-            ("prompt_column", "prompt"),
-            ("query_column", "query"),
-            ("response_column", "response"),
-            ("history_column", "history")
-        ]: # every dataset will have 4 columns same as each other
-            if getattr(dataset_attr, column_name) != target_name:
-                if getattr(dataset_attr, column_name):
-                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
-                else: # None or empty string
-                    dataset = dataset.add_column(target_name, dummy_data)
-        dataset = dataset.add_column("prefix", prefix_data)
+        for column_name in ["prompt", "query", "response", "history"]: # align datasets
+            if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
+                dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
+
+        if dataset_attr.source_prefix: # add prefix
+            dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix})
+
         all_datasets.append(dataset)
 
     if len(data_args.dataset_list) == 1:
-        all_datasets = all_datasets[0]
+        return all_datasets[0]
+    elif data_args.mix_strategy == "concat":
+        if data_args.streaming:
+            logger.warning("The samples between different datasets will not be mixed in streaming mode.")
+        return concatenate_datasets(all_datasets)
+    elif data_args.mix_strategy.startswith("interleave"):
+        if not data_args.streaming:
+            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
+        stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
+        return interleave_datasets(all_datasets, stopping_strategy=stopping_strategy)
     else:
-        all_datasets = concatenate_datasets(all_datasets)
-
-    return all_datasets
+        raise ValueError("Unknown mixing strategy.")
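The mixing logic introduced above either concatenates the loaded datasets or interleaves them: per the diff, a `mix_strategy` ending in `under` maps to `first_exhausted` (stop at the shortest dataset) and anything else to `all_exhausted` (keep cycling until every dataset is seen). A rough standalone illustration on toy in-memory datasets, not the repo's loader:

```python
# Hedged sketch of the two mixing strategies used above, on made-up toy data.
from datasets import Dataset, concatenate_datasets, interleave_datasets

ds_a = Dataset.from_dict({"prompt": ["a1", "a2"]})
ds_b = Dataset.from_dict({"prompt": ["b1", "b2", "b3"]})

mixed_concat = concatenate_datasets([ds_a, ds_b])  # simple back-to-back concatenation
mixed_under = interleave_datasets([ds_a, ds_b], stopping_strategy="first_exhausted")  # stops at the shortest
mixed_over = interleave_datasets([ds_a, ds_b], stopping_strategy="all_exhausted")     # revisits the shorter one
print(len(mixed_concat), len(mixed_under), len(mixed_over))
```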
@@ -1,65 +1,63 @@
-from typing import Literal
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal
 from itertools import chain
-from transformers import Seq2SeqTrainingArguments
-from transformers.tokenization_utils import PreTrainedTokenizer
 
-from datasets import Dataset
 
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.template import get_template
-from llmtuner.hparams import DataArguments
+
+if TYPE_CHECKING:
+    from datasets import Dataset
+    from transformers import Seq2SeqTrainingArguments
+    from transformers.tokenization_utils import PreTrainedTokenizer
+    from llmtuner.hparams import DataArguments
 
 
 def preprocess_dataset(
-    dataset: Dataset,
-    tokenizer: PreTrainedTokenizer,
-    data_args: DataArguments,
-    training_args: Seq2SeqTrainingArguments,
+    dataset: "Dataset",
+    tokenizer: "PreTrainedTokenizer",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo"]
-) -> Dataset:
+) -> "Dataset":
+    column_names = list(dataset.column_names or [])
+    template = get_template(data_args.template)
 
-    column_names = list(dataset.column_names)
-    prompt_template = get_template(data_args.prompt_template)
-
-    # support question with a single answer or multiple answers
-    def get_dialog(examples):
+    def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
         for i in range(len(examples["prompt"])):
-            if examples["prompt"][i] and examples["response"][i]:
-                query, answer = examples["prompt"][i], examples["response"][i]
-                query = query + "\n" + examples["query"][i] if examples["query"][i] else query
-                prefix = examples["prefix"][i] if examples["prefix"][i] else ""
-                dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
-                yield dialog
+            query, response = examples["prompt"][i], examples["response"][i]
+            query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
+            history = history if "history" in examples and examples["history"][i] else []
+            prefix = prefix if "prefix" in examples and examples["prefix"][i] else ""
+            yield query, response, history, prefix
 
-    def preprocess_pretrain_dataset(examples):
+    def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
         # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
-        text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
-        concatenated_ids = list(chain(*text_ids))
-        total_length = len(concatenated_ids)
-        block_size = data_args.max_source_length - 1
+        tokenized_examples = tokenizer(examples["prompt"], add_special_tokens=False)
+        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
+        total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
+        block_size = data_args.max_source_length
         # we drop the small remainder, and if the total_length < block_size, we exclude this batch
         total_length = (total_length // block_size) * block_size
         # split by chunks of max_source_length
-        result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
-                  for i in range(0, total_length, block_size)]
-        return {
-            "input_ids": result,
-            "labels": result.copy()
+        result = {
+            k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
+            for k, t in concatenated_examples.items()
         }
+        result["labels"] = result["input_ids"].copy()
+        return result
 
-    def preprocess_supervised_dataset(examples):
+    def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
         # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
         # for input with history, we build multiple input-label pairs just like:
         # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
-        model_inputs = {"input_ids": [], "labels": []}
+        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
         max_length = data_args.max_source_length + data_args.max_target_length
 
-        for dialog in get_dialog(examples):
+        for query, response, history, prefix in construct_example(examples):
             input_ids, labels = [], []
 
-            for i in range(len(dialog) // 2):
-                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0))
-                target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
+            for i, (query_i, resp_i) in enumerate(template.get_dialog(query, response, history, prefix)):
+                source_ids = tokenizer.encode(text=query_i, add_special_tokens=(i == 0))
+                target_ids = tokenizer.encode(text=resp_i, add_special_tokens=False)
 
                 if len(source_ids) > data_args.max_source_length:
                     source_ids = source_ids[:data_args.max_source_length]
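The rewritten pre-training preprocessor above tokenizes every prompt, concatenates each tokenizer output field, and slices it into `block_size` chunks, dropping the tail remainder. A toy, hedged walk-through of that grouping step with plain lists standing in for tokenizer output:

```python
# Hedged toy illustration of the block-grouping logic above; no tokenizer involved.
from itertools import chain

examples = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]}
block_size = 4

concatenated = {k: list(chain(*v)) for k, v in examples.items()}
total_length = (len(concatenated["input_ids"]) // block_size) * block_size  # 10 -> 8, remainder dropped

result = {
    k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
    for k, t in concatenated.items()
}
result["labels"] = result["input_ids"].copy()
print(result["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```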
@@ -73,19 +71,20 @@ def preprocess_dataset(
                 labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
 
             model_inputs["input_ids"].append(input_ids)
+            model_inputs["attention_mask"].append([1] * len(input_ids))
             model_inputs["labels"].append(labels)
 
         return model_inputs
 
-    def preprocess_unsupervised_dataset(examples):
+    def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
         # build inputs with format `<bos> X` and labels with format `<bos> Y`
-        model_inputs = {"input_ids": [], "labels": []}
+        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
 
-        for dialog in get_dialog(examples):
-            prompt, answer = "".join(dialog[:-1]), dialog[-1]
+        for query, response, history, prefix in construct_example(examples):
+            prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
 
             source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
-            target_ids = tokenizer.encode(text=answer, add_special_tokens=True)
+            target_ids = tokenizer.encode(text=response, add_special_tokens=True)
 
             if len(source_ids) > data_args.max_source_length:
                 source_ids = source_ids[:data_args.max_source_length]

@@ -93,6 +92,7 @@ def preprocess_dataset(
                 target_ids = target_ids[:data_args.max_target_length]
 
             model_inputs["input_ids"].append(source_ids)
+            model_inputs["attention_mask"].append([1] * len(source_ids))
             model_inputs["labels"].append(target_ids)
 
         return model_inputs

@@ -100,12 +100,12 @@ def preprocess_dataset(
     def preprocess_pairwise_dataset(examples):
         # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
         model_inputs = {"accept_ids": [], "reject_ids": []}
-        for dialog in get_dialog(examples):
-            prompt, answer = "".join(dialog[:-1]), dialog[-1]
+        for query, response, history, prefix in construct_example(examples):
+            prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
 
             source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
-            accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
-            reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
+            accept_ids = tokenizer.encode(text=response[0], add_special_tokens=False)
+            reject_ids = tokenizer.encode(text=response[1], add_special_tokens=False)
 
             if len(source_ids) > data_args.max_source_length:
                 source_ids = source_ids[:data_args.max_source_length]

@@ -141,34 +141,44 @@ def preprocess_dataset(
         print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
 
     if stage == "pt":
+        dataset = dataset.filter(lambda example: example["prompt"])
         preprocess_function = preprocess_pretrain_dataset
-    elif stage == "sft":
-        if not training_args.predict_with_generate:
-            preprocess_function = preprocess_supervised_dataset
-        else:
-            preprocess_function = preprocess_unsupervised_dataset
+    elif stage == "sft" and not training_args.predict_with_generate:
+        dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
+        preprocess_function = preprocess_supervised_dataset
     elif stage == "rm":
+        dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
         preprocess_function = preprocess_pairwise_dataset
-    elif stage == "ppo":
+    else:
+        dataset = dataset.filter(lambda example: example["prompt"])
         preprocess_function = preprocess_unsupervised_dataset
 
     with training_args.main_process_first(desc="dataset map pre-processing"):
+        kwargs = {}
+        if not data_args.streaming:
+            kwargs = dict(
+                num_proc=data_args.preprocessing_num_workers,
+                load_from_cache_file=not data_args.overwrite_cache,
+                desc="Running tokenizer on dataset"
+            )
 
         dataset = dataset.map(
             preprocess_function,
             batched=True,
-            num_proc=data_args.preprocessing_num_workers,
             remove_columns=column_names,
-            load_from_cache_file=not data_args.overwrite_cache,
-            desc="Running tokenizer on dataset"
+            **kwargs
         )
 
+        if data_args.streaming:
+            dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
+
         if stage == "pt":
-            print_unsupervised_dataset_example(dataset[0])
+            print_unsupervised_dataset_example(next(iter(dataset)))
         elif stage == "sft":
-            print_supervised_dataset_example(dataset[0])
+            print_supervised_dataset_example(next(iter(dataset)))
         elif stage == "rm":
-            print_pairwise_dataset_example(dataset[0])
+            print_pairwise_dataset_example(next(iter(dataset)))
        elif stage == "ppo":
-            print_unsupervised_dataset_example(dataset[0])
+            print_unsupervised_dataset_example(next(iter(dataset)))
 
         return dataset
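All of the preprocess functions above receive batches from `Dataset.map(..., batched=True)` as a dict of column lists, which is why `construct_example` indexes `examples["prompt"][i]`. A hedged toy example of that calling convention; the columns and values are made up:

```python
# Hedged toy example of the batched map convention used by the preprocess functions above.
from datasets import Dataset

toy = Dataset.from_dict({"prompt": ["hi", "bye"], "response": ["hello", "goodbye"]})

def preprocess(examples):
    # examples == {"prompt": ["hi", "bye"], "response": ["hello", "goodbye"]}
    return {"text": [p + " " + r for p, r in zip(examples["prompt"], examples["response"])]}

toy = toy.map(preprocess, batched=True, remove_columns=["prompt", "response"])
print(toy[0])  # {'text': 'hi hello'}
```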
@@ -1,13 +1,12 @@
-from typing import Dict
-from datasets import Dataset
+from typing import TYPE_CHECKING, Dict
+
+if TYPE_CHECKING:
+    from datasets import Dataset
 
 
-def split_dataset(
-    dataset: Dataset, dev_ratio: float, do_train: bool
-) -> Dict[str, Dataset]:
-    # Split the dataset
+def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]:
     if do_train:
-        if dev_ratio > 1e-6:
+        if dev_ratio > 1e-6: # Split the dataset
             dataset = dataset.train_test_split(test_size=dev_ratio)
             return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
         else:
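`split_dataset` above hands the held-out split off to `train_test_split` whenever `dev_ratio` is non-zero. A toy, hedged call showing the resulting sizes:

```python
# Hedged toy illustration of the dev split performed above (10% held out).
from datasets import Dataset

ds = Dataset.from_dict({"prompt": [f"q{i}" for i in range(100)]})
splits = ds.train_test_split(test_size=0.1)
print(len(splits["train"]), len(splits["test"]))  # 90 10
```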
@@ -1,16 +1,13 @@
 import os
 import json
 import time
+from typing import TYPE_CHECKING
 from datetime import timedelta
 
-from transformers import (
-    TrainerCallback,
-    TrainerControl,
-    TrainerState,
-    TrainingArguments
-)
-from transformers.trainer_callback import TrainerControl, TrainerState
-from transformers.training_args import TrainingArguments
+from transformers import TrainerCallback
+
+if TYPE_CHECKING:
+    from transformers import TrainingArguments, TrainerState, TrainerControl
 
 
 class LogCallback(TrainerCallback):
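As in several other files in this commit, the argument classes are now imported only under `typing.TYPE_CHECKING` and referenced as string annotations, so they cost nothing at runtime. A minimal, hedged sketch of that pattern with made-up names:

```python
# Hedged sketch of the TYPE_CHECKING pattern adopted across this commit.
# "heavy_module" and "HeavyConfig" are hypothetical names, not part of the repo.
from typing import TYPE_CHECKING

if TYPE_CHECKING:  # evaluated by type checkers only, skipped at runtime
    from heavy_module import HeavyConfig

def configure(config: "HeavyConfig") -> None:  # string annotation avoids a runtime import
    print(config)
```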
@@ -20,13 +17,13 @@ class LogCallback(TrainerCallback):
         self.start_time = time.time()
         self.tracker = {}
 
-    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the beginning of training.
         """
         self.start_time = time.time()
 
-    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+    def on_step_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the beginning of a training step. If using gradient accumulation, one training step
         might take several inputs.

@@ -35,7 +32,7 @@ class LogCallback(TrainerCallback):
             control.should_epoch_stop = True
             control.should_training_stop = True
 
-    def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of an substep during gradient accumulation.
         """

@@ -43,7 +40,7 @@ class LogCallback(TrainerCallback):
             control.should_epoch_stop = True
             control.should_training_stop = True
 
-    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
+    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
         r"""
         Event called after logging the last logs.
         """
@@ -1,12 +1,14 @@
 import torch
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
-from transformers.modeling_utils import PreTrainedModel
 from transformers.generation.utils import LogitsProcessorList
 from transformers.generation.logits_process import LogitsProcessor
 
 from llmtuner.extras.constants import LAYERNORM_NAMES
 
+if TYPE_CHECKING:
+    from transformers.modeling_utils import PreTrainedModel
 
 
 class AverageMeter:
     r"""

@@ -44,29 +46,37 @@ def get_logits_processor() -> LogitsProcessorList:
     return logits_processor
 
 
-def print_trainable_params(model: torch.nn.Module) -> None:
+def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
+    r"""
+    Returns the number of trainable parameters and number of all parameters in the model.
+    """
     trainable_params, all_param = 0, 0
     for param in model.parameters():
         num_params = param.numel()
         # if using DS Zero 3 and the weights are initialized empty
         if num_params == 0 and hasattr(param, "ds_numel"):
             num_params = param.ds_numel
 
+        # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
+        if param.__class__.__name__ == "Params4bit":
+            num_params = num_params * 2
 
         all_param += num_params
         if param.requires_grad:
             trainable_params += num_params
-    print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
-        trainable_params, all_param, 100 * trainable_params / all_param))
+
+    return trainable_params, all_param
 
 
 # Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
 # Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
 def prepare_model_for_training(
-    model: PreTrainedModel,
+    model: "PreTrainedModel",
     finetuning_type: str,
     output_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
     layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
-) -> PreTrainedModel:
+) -> "PreTrainedModel":
 
     for name, param in model.named_parameters():
         if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
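Where the old `print_trainable_params` printed its summary directly, the new `count_parameters` returns the two counts so callers format the message themselves. A hedged usage sketch, assuming the function lands in `llmtuner.extras.misc` (the module the rest of this commit already imports `get_logits_processor` from) and using a toy linear layer in place of a real model:

```python
# Hedged usage sketch for count_parameters; the tiny model below is only a stand-in.
import torch
from llmtuner.extras.misc import count_parameters  # assumed import path

model = torch.nn.Linear(8, 2)
model.bias.requires_grad = False  # freeze something so the two counts differ

trainable_params, all_param = count_parameters(model)
print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
    trainable_params, all_param, 100 * trainable_params / all_param
))
```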
@@ -84,6 +94,9 @@ def prepare_model_for_training(
         model.config.use_cache = False # turn off when gradient checkpointing is enabled
 
     if finetuning_type != "full" and hasattr(model, output_layer_name):
+        if hasattr(model, "config") and hasattr(model.config, "pretraining_tp"):
+            model.config.pretraining_tp = 1 # disable TP for LoRA (https://github.com/huggingface/peft/pull/728)
 
         output_layer: torch.nn.Linear = getattr(model, output_layer_name)
         input_dtype = output_layer.weight.dtype
 

@@ -92,11 +105,8 @@ def prepare_model_for_training(
             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return super().forward(x.to(input_dtype)).to(torch.float32)
 
-        new_output_layer = CastOutputToFloat(output_layer)
-        # adapt to LLaMA-2's pretraining_tp (actually LLaMA models can automatically do casting but BLOOM models cannot)
-        # (https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/llama/modeling_llama.py#L819)
-        setattr(new_output_layer, "weight", output_layer.weight)
-        setattr(model, output_layer_name, new_output_layer)
+        setattr(model, output_layer_name, CastOutputToFloat(output_layer))
 
     return model
 
 
@@ -1,6 +1,6 @@
 import os
 import torch
-from typing import Dict, Optional
+from typing import Dict
 
 from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
 from transformers.modeling_utils import load_sharded_checkpoint

@@ -12,12 +12,12 @@ from llmtuner.extras.logging import get_logger
 logger = get_logger(__name__)
 
 
-def get_state_dict(model: torch.nn.Module, trainable_only: Optional[bool] = True) -> Dict[str, torch.Tensor]:
-    state_dict = model.state_dict()
+def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor] = model.state_dict()
     filtered_state_dict = {}
 
     for k, v in model.named_parameters():
-        if (not trainable_only) or v.requires_grad:
+        if v.requires_grad:
             filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
 
     return filtered_state_dict
@@ -11,37 +11,46 @@ class Template:
     use_history: bool
 
     def get_prompt(
-        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
+        self,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        prefix: Optional[str] = "",
+        eos_token: Optional[str] = "</s>"
     ) -> str:
         r"""
         Returns a string containing prompt without response.
         """
-        return "".join(self._format_example(query, history, prefix))
+        return eos_token.join(map(lambda x: x[0] + x[1], self._format_example(query, history, prefix)))
 
     def get_dialog(
-        self, query: str, resp: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
-    ) -> List[str]:
+        self,
+        query: str,
+        resp: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        prefix: Optional[str] = ""
+    ) -> List[Tuple[str, str]]:
         r"""
-        Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response.
+        Returns a list containing prompt-response pairs.
         """
-        return self._format_example(query, history, prefix) + [resp]
+        result = self._format_example(query, history, prefix)
+        result[-1][-1] = resp
+        return result
 
     def _format_example(
-        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
-    ) -> List[str]:
+        self,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        prefix: Optional[str] = ""
+    ) -> List[Tuple[str, str]]:
         prefix = prefix or self.prefix # use prefix if provided
         prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
         history = history if (history and self.use_history) else []
-        history = history + [(query, "<dummy>")]
-        convs = []
-        for turn_idx, (user_query, bot_resp) in enumerate(history):
-            if turn_idx == 0:
-                convs.append(prefix + self.prompt.format(query=user_query))
-                convs.append(bot_resp)
-            else:
-                convs.append(self.sep + self.prompt.format(query=user_query))
-                convs.append(bot_resp)
-        return convs[:-1] # drop last
+        history = history + [(query, "")]
+        convs = [
+            [(self.sep if turn_idx else prefix) + self.prompt.format(query=query_i), resp_i]
+            for turn_idx, (query_i, resp_i) in enumerate(history)
+        ]
+        return convs
 
 
 templates: Dict[str, Template] = {}
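With this rewrite, `get_prompt` joins each (query, response) turn with the tokenizer's EOS token and leaves the final response empty for the model to fill in. A hedged, hand-evaluated walk-through using the vicuna-style `USER: {query} ASSISTANT: ` prompt that appears later in this diff; the chat turns and the shortened prefix below are invented:

```python
# Hedged walk-through of the new prompt construction using plain string operations.
prompt_fmt = "USER: {query} ASSISTANT: "
prefix = "A chat between a curious user and an artificial intelligence assistant."
sep = ""           # the vicuna-style template in this diff now uses an empty sep
eos_token = "</s>"

history = [("Hi", "Hello, how can I help?")]  # made-up earlier turn
query = "Tell me a joke"

pairs = [
    [(sep if turn_idx else prefix + sep) + prompt_fmt.format(query=q), r]
    for turn_idx, (q, r) in enumerate(history + [(query, "")])
]
prompt = eos_token.join(q + r for q, r in pairs)
print(prompt)
# A chat between a curious user and an artificial intelligence assistant.USER: Hi ASSISTANT: Hello, how can I help?</s>USER: Tell me a joke ASSISTANT: 
```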
@@ -103,7 +112,7 @@ register_template(
     "explain why instead of answering something not correct. "
     "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
     prompt=" [INST] {query} [/INST] ",
-    sep="</s>",
+    sep="",
     use_history=True
 )
 

@@ -131,7 +140,7 @@ register_template(
     prefix="A chat between a curious user and an artificial intelligence assistant. "
     "The assistant gives helpful, detailed, and polite answers to the user's questions.",
     prompt="USER: {query} ASSISTANT: ",
-    sep="</s>",
+    sep="",
     use_history=True
 )
 

@@ -216,7 +225,7 @@ register_template(
     name="baichuan",
     prefix="",
     prompt="<reserved_102>{query}<reserved_103>",
-    sep="</s>",
+    sep="",
     use_history=True
 )
 
@ -1,6 +1,6 @@
 import os
 import json
-from typing import List, Optional
+from typing import List, Literal, Optional
 from dataclasses import dataclass, field


@ -16,10 +16,10 @@ class DatasetAttr:
         return self.dataset_name

     def __post_init__(self):
-        self.prompt_column = "instruction"
-        self.query_column = "input"
-        self.response_column = "output"
-        self.history_column = None
+        self.prompt = "instruction"
+        self.query = "input"
+        self.response = "output"
+        self.history = None


 @dataclass
@ -27,8 +27,11 @@ class DataArguments:
     """
     Arguments pertaining to what data we are going to input our model for training and evaluation.
     """
+    template: str = field(
+        metadata={"help": "Which template to use for constructing prompts in training and inference."}
+    )
     dataset: Optional[str] = field(
-        default="alpaca_zh",
+        default="alpaca_en",
         metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
     )
     dataset_dir: Optional[str] = field(
@ -39,6 +42,18 @@ class DataArguments:
         default="train",
         metadata={"help": "Which dataset split to use for training and evaluation."}
     )
+    streaming: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Enable streaming mode."}
+    )
+    buffer_size: Optional[int] = field(
+        default=16384,
+        metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
+    )
+    mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
+        default="concat",
+        metadata={"help": "Strategy to use in dataset mixing."}
+    )
     overwrite_cache: Optional[bool] = field(
         default=False,
         metadata={"help": "Overwrite the cached training and evaluation sets."}
@ -75,10 +90,6 @@ class DataArguments:
         default=0,
         metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
     )
-    prompt_template: Optional[str] = field(
-        default="default",
-        metadata={"help": "Which template to use for constructing prompts in training and inference."}
-    )

     def init_for_training(self): # support mixing multiple datasets
         dataset_names = [ds.strip() for ds in self.dataset.split(",")]
@ -111,9 +122,9 @@ class DataArguments:
             dataset_attr.source_prefix = prefix_list[i]

             if "columns" in dataset_info[name]:
-                dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
-                dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
-                dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
-                dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
+                dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
+                dataset_attr.query = dataset_info[name]["columns"].get("query", None)
+                dataset_attr.response = dataset_info[name]["columns"].get("response", None)
+                dataset_attr.history = dataset_info[name]["columns"].get("history", None)

             self.dataset_list.append(dataset_attr)
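With `template` now a required field and the new streaming options in place, the data arguments can be filled from a plain dictionary through `HfArgumentParser`. A trimmed-down sketch; `MiniDataArguments` is a stand-in for the real dataclass and the Literal choices of `mix_strategy` are omitted for brevity:

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class MiniDataArguments:
    # `template` has no default, so it must now be supplied explicitly.
    template: str = field(metadata={"help": "Which template to use for constructing prompts."})
    dataset: Optional[str] = field(default="alpaca_en")
    streaming: Optional[bool] = field(default=False)
    buffer_size: Optional[int] = field(default=16384)
    mix_strategy: Optional[str] = field(default="concat")  # "concat" / "interleave_under" / "interleave_over"

parser = HfArgumentParser(MiniDataArguments)
(data_args,) = parser.parse_dict({"template": "llama2", "streaming": True})
print(data_args.mix_strategy)  # "concat": multiple datasets are concatenated unless overridden
```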
@ -5,7 +5,7 @@ from dataclasses import dataclass, field
 @dataclass
 class GeneralArguments:
     """
-    Arguments pertaining to which techniques we are going to fine-tuning with.
+    Arguments pertaining to which stage we are going to perform.
     """
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
         default="sft",
@ -1,7 +1,7 @@
 import os
 import torch
+from typing import TYPE_CHECKING

-from transformers.modeling_utils import PreTrainedModel
 from peft import (
     PeftModel,
     TaskType,
@ -12,19 +12,22 @@ from peft.utils import CONFIG_NAME, WEIGHTS_NAME

 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.save_and_load import load_trainable_params
-from llmtuner.hparams import ModelArguments, FinetuningArguments
+
+if TYPE_CHECKING:
+    from transformers.modeling_utils import PreTrainedModel
+    from llmtuner.hparams import ModelArguments, FinetuningArguments


 logger = get_logger(__name__)


 def init_adapter(
-    model: PreTrainedModel,
-    model_args: ModelArguments,
-    finetuning_args: FinetuningArguments,
+    model: "PreTrainedModel",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
     is_trainable: bool,
     is_mergeable: bool
-) -> PreTrainedModel:
+) -> "PreTrainedModel":
     r"""
     Initializes the adapters.
@ -1,6 +1,6 @@
 import os
 import torch
-from typing import Literal, Optional, Tuple
+from typing import TYPE_CHECKING, Literal, Optional, Tuple

 from transformers import (
     AutoConfig,
@ -16,11 +16,13 @@ from transformers.tokenization_utils import PreTrainedTokenizerBase
 from trl import AutoModelForCausalLMWithValueHead

 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import prepare_model_for_training, print_trainable_params
+from llmtuner.extras.misc import count_parameters, prepare_model_for_training
 from llmtuner.extras.save_and_load import load_valuehead_params
-from llmtuner.hparams import ModelArguments, FinetuningArguments
 from llmtuner.tuner.core.adapter import init_adapter

+if TYPE_CHECKING:
+    from llmtuner.hparams import ModelArguments, FinetuningArguments


 logger = get_logger(__name__)

@ -33,8 +35,8 @@ require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")


 def load_model_and_tokenizer(
-    model_args: ModelArguments,
-    finetuning_args: FinetuningArguments,
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
     is_trainable: Optional[bool] = False,
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
 ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
@ -141,6 +143,9 @@ def load_model_and_tokenizer(
         model.requires_grad_(False) # fix all model params
         model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16

-    print_trainable_params(model)
+    trainable_params, all_param = count_parameters(model)
+    logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
+        trainable_params, all_param, 100 * trainable_params / all_param
+    ))

     return model, tokenizer
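The loader now reports parameter counts through `count_parameters` instead of `print_trainable_params`. A hedged sketch of what such a helper is expected to return; this is not necessarily the repo's exact implementation:

```python
import torch

def count_parameters(model: torch.nn.Module) -> tuple:
    # Returns (trainable, total) parameter counts over all module parameters.
    trainable, total = 0, 0
    for param in model.parameters():
        num = param.numel()
        total += num
        if param.requires_grad:
            trainable += num
    return trainable, total

model = torch.nn.Linear(10, 10)   # 110 parameters in total
model.bias.requires_grad_(False)
trainable_params, all_param = count_parameters(model)
print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
    trainable_params, all_param, 100 * trainable_params / all_param
))
# trainable params: 100 || all params: 110 || trainable%: 90.9091
```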
@ -19,20 +19,39 @@ from llmtuner.hparams import (
 logger = get_logger(__name__)


+def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None):
+    if args is not None:
+        return parser.parse_dict(args)
+    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
+        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
+    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
+    else:
+        return parser.parse_args_into_dataclasses()
+
+
+def parse_train_args(
+    args: Optional[Dict[str, Any]] = None
+) -> Tuple[GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
+    parser = HfArgumentParser((
+        GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments
+    ))
+    return _parse_args(parser, args)
+
+
+def parse_infer_args(
+    args: Optional[Dict[str, Any]] = None
+) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
+    parser = HfArgumentParser((
+        ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+    ))
+    return _parse_args(parser, args)
+
+
 def get_train_args(
     args: Optional[Dict[str, Any]] = None
 ) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
-    parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments))
-
-    if args is not None:
-        model_args, data_args, training_args, finetuning_args, general_args = parser.parse_dict(args)
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
-        model_args, data_args, training_args, finetuning_args, general_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        model_args, data_args, training_args, finetuning_args, general_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
-    else:
-        model_args, data_args, training_args, finetuning_args, general_args = parser.parse_args_into_dataclasses()
+    general_args, model_args, data_args, training_args, finetuning_args = parse_train_args(args)

     # Setup logging
     if training_args.should_log:
@ -73,13 +92,22 @@ def get_train_args(
     if training_args.do_train and (not training_args.fp16):
         logger.warning("We recommend enable fp16 mixed precision training.")

-    if data_args.prompt_template == "default":
-        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
-
-    if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
-        logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
+    if (
+        training_args.local_rank != -1
+        and training_args.ddp_find_unused_parameters is None
+        and finetuning_args.finetuning_type == "lora"
+    ):
+        logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
         training_args.ddp_find_unused_parameters = False

+    if data_args.max_samples is not None and data_args.streaming:
+        logger.warning("`max_samples` is incompatible with `streaming`. Disabling streaming mode.")
+        data_args.streaming = False
+
+    if data_args.dev_ratio > 1e-6 and data_args.streaming:
+        logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
+        data_args.dev_ratio = 0
+
     training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning

     if model_args.quantization_bit is not None:
@ -106,17 +134,7 @@ def get_train_args(
 def get_infer_args(
     args: Optional[Dict[str, Any]] = None
 ) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
-    parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments))
-
-    if args is not None:
-        model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args)
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
-        model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
-    else:
-        model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
+    model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)

     assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
         "Quantization is only compatible with the LoRA method."
@ -128,7 +146,4 @@ def get_infer_args(
     assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
         "Quantized model only accepts a single checkpoint."

-    if data_args.prompt_template == "default":
-        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
-
     return model_args, data_args, finetuning_args, generating_args
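The duplicated parsing branches in `get_train_args` and `get_infer_args` collapse into the shared `_parse_args` helper, which dispatches on how the arguments arrive. A standalone sketch of that dispatch order with a hypothetical `DemoArguments` dataclass:

```python
import os
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

from transformers import HfArgumentParser

@dataclass
class DemoArguments:
    template: str = field(default="default")
    dataset: str = field(default="alpaca_en")

def parse_demo_args(args: Optional[Dict[str, Any]] = None):
    # Same precedence as the shared helper above: an explicit dict wins,
    # then a single .yaml/.json path on the command line, then raw CLI flags.
    parser = HfArgumentParser(DemoArguments)
    if args is not None:
        return parser.parse_dict(args)
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
    else:
        return parser.parse_args_into_dataclasses()

(demo_args,) = parse_demo_args({"template": "llama2"})
print(demo_args.template, demo_args.dataset)  # llama2 alpaca_en
```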
@ -1,16 +1,19 @@
 import os
 import torch
-from typing import Dict, Optional
+from typing import TYPE_CHECKING, Dict, Optional

 from transformers import Seq2SeqTrainer
-from transformers.trainer import TRAINING_ARGS_NAME
+from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
 from transformers.modeling_utils import PreTrainedModel, unwrap_model
 from peft import PeftModel
+from trl import PreTrainedModelWrapper

 from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params, load_valuehead_params
-from llmtuner.hparams import FinetuningArguments
+from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
+
+if TYPE_CHECKING:
+    from llmtuner.hparams import FinetuningArguments


 logger = get_logger(__name__)

@ -21,7 +24,7 @@ class PeftTrainer(Seq2SeqTrainer):
     Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
     """

-    def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
+    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
         self._remove_log()

@ -42,31 +45,35 @@ class PeftTrainer(Seq2SeqTrainer):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         os.makedirs(output_dir, exist_ok=True)
         logger.info(f"Saving model checkpoint to {output_dir}")

         model = unwrap_model(self.model)
+        state_dict = state_dict or get_state_dict(model)

-        if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
-            backbone_model = getattr(model, "pretrained_model")
-            torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
-        else:
-            backbone_model = model
-
-        if isinstance(backbone_model, PeftModel): # LoRA tuning
-            backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
-        elif isinstance(backbone_model, PreTrainedModel): # freeze/full tuning
-            backbone_model.config.use_cache = True
-            backbone_model.save_pretrained(
-                output_dir,
-                state_dict=get_state_dict(backbone_model, trainable_only=(self.finetuning_args.finetuning_type != "full")),
-                safe_serialization=self.args.save_safetensors
-            )
-            backbone_model.config.use_cache = False
-            if self.tokenizer is not None:
-                self.tokenizer.save_pretrained(output_dir)
+        if isinstance(model, PreTrainedModelWrapper):
+            model_params, v_head_params = {}, {}
+            for name in state_dict.keys():
+                if name.startswith("pretrained_model."):
+                    model_params[name.replace("pretrained_model.", "")] = state_dict[name]
+                elif name.startswith("v_head."):
+                    v_head_params[name.replace("v_head.", "")] = state_dict[name]
+
+            torch.save(v_head_params, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
+            state_dict = model_params
+            model = model.pretrained_model
+
+        if isinstance(model, (PeftModel, PreTrainedModel)):
+            model.config.use_cache = True
+            model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)
+            model.config.use_cache = False
         else:
-            logger.warning("No model to save.")
+            torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)

         with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
             f.write(self.args.to_json_string() + "\n")

         self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))

     def _load_best_model(self):

@ -76,16 +83,15 @@ class PeftTrainer(Seq2SeqTrainer):
         Subclass and override to inject custom behavior. It should not be directly used by external scripts.
         """
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")

         model = unwrap_model(self.model)
-        backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model

-        if isinstance(backbone_model, PeftModel):
-            backbone_model.load_adapter(self.state.best_model_checkpoint, backbone_model.active_adapter)
-            if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
-                model.v_head.load_state_dict({
-                    "summary.weight": getattr(model, "reward_head_weight"),
-                    "summary.bias": getattr(model, "reward_head_bias")
-                })
+        if isinstance(model, PreTrainedModelWrapper):
+            model.v_head.load_state_dict(torch.load(
+                os.path.join(self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME), map_location="cpu"
+            ))
+            model = model.pretrained_model
+
+        if isinstance(model, PeftModel):
+            model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
         else: # freeze/full-tuning
-            load_trainable_params(backbone_model, self.state.best_model_checkpoint)
+            load_trainable_params(model, self.state.best_model_checkpoint)
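The rewritten `_save` splits the flat state dict of a value-head model by key prefix before saving the backbone and the value head separately. A small self-contained illustration with hypothetical keys, where strings stand in for tensors:

```python
# Hypothetical flat state dict of a value-head wrapper: backbone weights are
# prefixed with "pretrained_model." and the scalar head with "v_head.".
state_dict = {
    "pretrained_model.model.embed_tokens.weight": "...",
    "pretrained_model.lm_head.weight": "...",
    "v_head.summary.weight": "...",
    "v_head.summary.bias": "...",
}

model_params, v_head_params = {}, {}
for name, tensor in state_dict.items():
    if name.startswith("pretrained_model."):
        model_params[name.replace("pretrained_model.", "")] = tensor
    elif name.startswith("v_head."):
        v_head_params[name.replace("v_head.", "")] = tensor

print(sorted(model_params))   # ['lm_head.weight', 'model.embed_tokens.weight']
print(sorted(v_head_params))  # ['summary.bias', 'summary.weight']
```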
@ -2,21 +2,25 @@ import os
 import math
 import torch
 from tqdm import tqdm
-from typing import Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional

-from transformers import Seq2SeqTrainingArguments, TrainerState, TrainerControl
+from transformers import TrainerState, TrainerControl
 from transformers.modeling_utils import PreTrainedModel

 from trl import PPOTrainer
 from trl.core import LengthSampler

-from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, get_logits_processor
-from llmtuner.hparams import FinetuningArguments
 from llmtuner.tuner.core.trainer import PeftTrainer
 from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model

+if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments
+    from llmtuner.extras.callbacks import LogCallback
+    from llmtuner.hparams import FinetuningArguments


 logger = get_logger(__name__)

@ -27,9 +31,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
     """
     def __init__(
         self,
-        training_args: Seq2SeqTrainingArguments,
-        finetuning_args: FinetuningArguments,
-        callbacks: List[LogCallback],
+        training_args: "Seq2SeqTrainingArguments",
+        finetuning_args: "FinetuningArguments",
+        callbacks: List["LogCallback"],
         **kwargs
     ):
         PPOTrainer.__init__(self, **kwargs)
@ -1,11 +1,13 @@
 import torch
-from typing import Dict, List, Literal, Optional, Tuple
-from trl import AutoModelForCausalLMWithValueHead
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple

 from llmtuner.extras.constants import LAYERNORM_NAMES

+if TYPE_CHECKING:
+    from trl import AutoModelForCausalLMWithValueHead

-def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
+
+def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
     if target == "reward": # save default head temporarily
         valuehead_state_dict = model.v_head.state_dict()
         setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])

@ -19,10 +21,10 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def


 def cast_layernorm_dtype(
-    model: AutoModelForCausalLMWithValueHead,
+    model: "AutoModelForCausalLMWithValueHead",
     layer_norm_names: List[str] = LAYERNORM_NAMES,
     layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
-) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
+) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]:

     layer_norm_state_dict = {}

@ -2,26 +2,30 @@
 # https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py

 import math
+from typing import TYPE_CHECKING
 from trl import PPOConfig
 from torch.optim import AdamW
 from typing import Optional, List
-from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, TrainerCallback
+from transformers import DataCollatorForSeq2Seq
 from transformers.optimization import get_scheduler

 from llmtuner.dsets import get_dataset, preprocess_dataset
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.ppo.trainer import PPOPeftTrainer

+if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments


 def run_ppo(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: Seq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
@ -1,24 +1,27 @@
 # Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py

 import math
-from typing import Optional, List
-from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
+from typing import TYPE_CHECKING, Optional, List
+from transformers import DataCollatorForSeq2Seq

 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.core.trainer import PeftTrainer

+if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments


 def run_pt(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: Seq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
@ -15,5 +15,8 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
         We generate 2 * n examples where the first n examples represent chosen examples and
         the last n examples represent rejected examples.
         """
-        features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
+        features = [
+            {"input_ids": feature[key], "attention_mask": [1] * len(feature[key])}
+            for key in ("accept_ids", "reject_ids") for feature in features
+        ]
         return super().__call__(features)
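The collator change attaches an explicit attention mask to each flattened example. A toy illustration of the 2 * n layout it produces, with made-up token ids:

```python
# Toy batch with token-id lists standing in for tokenized chosen/rejected answers.
batch = [
    {"accept_ids": [1, 2, 3], "reject_ids": [1, 9]},
    {"accept_ids": [4, 5], "reject_ids": [4, 6, 7, 8]},
]

# Flatten to 2 * n features: all chosen examples first, then all rejected ones,
# each carrying an all-ones attention mask of matching length.
features = [
    {"input_ids": ex[key], "attention_mask": [1] * len(ex[key])}
    for key in ("accept_ids", "reject_ids") for ex in batch
]

for f in features:
    print(f)
# {'input_ids': [1, 2, 3], 'attention_mask': [1, 1, 1]}
# {'input_ids': [4, 5], 'attention_mask': [1, 1]}
# {'input_ids': [1, 9], 'attention_mask': [1, 1]}
# {'input_ids': [4, 6, 7, 8], 'attention_mask': [1, 1, 1, 1]}
```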
@ -1,13 +1,15 @@
 import os
 import json
 import torch
-from typing import Dict, List, Optional, Tuple, Union
-from transformers.trainer import PredictionOutput
-from transformers.modeling_utils import PreTrainedModel
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 from llmtuner.extras.logging import get_logger
 from llmtuner.tuner.core.trainer import PeftTrainer

+if TYPE_CHECKING:
+    from transformers.trainer import PredictionOutput
+    from transformers.modeling_utils import PreTrainedModel


 logger = get_logger(__name__)

@ -23,7 +25,7 @@ class PairwisePeftTrainer(PeftTrainer):

     def compute_loss(
         self,
-        model: PreTrainedModel,
+        model: "PreTrainedModel",
         inputs: Dict[str, torch.Tensor],
         return_outputs: Optional[bool] = False
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:

@ -46,7 +48,7 @@ class PairwisePeftTrainer(PeftTrainer):

     def save_predictions(
         self,
-        predict_results: PredictionOutput
+        predict_results: "PredictionOutput"
     ) -> None:
         r"""
         Saves model predictions to `output_dir`.
@ -2,25 +2,27 @@
 # https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
 # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py

-from typing import Optional, List
-from transformers import Seq2SeqTrainingArguments, TrainerCallback
+from typing import TYPE_CHECKING, Optional, List

 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.rm.metric import compute_accuracy
 from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
 from llmtuner.tuner.rm.trainer import PairwisePeftTrainer

+if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments


 def run_rm(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: Seq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
@ -1,7 +1,6 @@
 import numpy as np
 from dataclasses import dataclass
-from typing import Dict, Sequence, Tuple, Union
-from transformers.tokenization_utils import PreTrainedTokenizer
+from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union

 import jieba
 from rouge_chinese import Rouge
@ -9,6 +8,9 @@ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

 from llmtuner.extras.constants import IGNORE_INDEX

+if TYPE_CHECKING:
+    from transformers.tokenization_utils import PreTrainedTokenizer


 @dataclass
 class ComputeMetrics:
@ -16,7 +18,7 @@ class ComputeMetrics:
     Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
     """

-    tokenizer: PreTrainedTokenizer
+    tokenizer: "PreTrainedTokenizer"

     def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
         r"""
@ -3,13 +3,15 @@ import json
 import torch
 import numpy as np
 import torch.nn as nn
-from typing import Any, Dict, List, Optional, Tuple, Union
-from transformers.trainer import PredictionOutput
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.logging import get_logger
 from llmtuner.tuner.core.trainer import PeftTrainer

+if TYPE_CHECKING:
+    from transformers.trainer import PredictionOutput


 logger = get_logger(__name__)

@ -81,7 +83,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):

     def save_predictions(
         self,
-        predict_results: PredictionOutput
+        predict_results: "PredictionOutput"
     ) -> None:
         r"""
         Saves model predictions to `output_dir`.
@ -1,25 +1,28 @@
 # Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py

-from typing import Optional, List
-from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
+from typing import TYPE_CHECKING, Optional, List
+from transformers import DataCollatorForSeq2Seq

 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.sft.metric import ComputeMetrics
 from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer

+if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments


 def run_sft(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: Seq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
@ -54,7 +54,7 @@ class WebChatModel(ChatModel):
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
             quantization_bit=int(quantization_bit) if quantization_bit else None,
-            prompt_template=template,
+            template=template,
             source_prefix=source_prefix
         )
         super().__init__(*get_infer_args(args))
@ -111,7 +111,7 @@ class Runner:
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
             quantization_bit=int(quantization_bit) if quantization_bit else None,
-            prompt_template=template,
+            template=template,
             source_prefix=source_prefix,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),

@ -201,7 +201,7 @@ class Runner:
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
             quantization_bit=int(quantization_bit) if quantization_bit else None,
-            prompt_template=template,
+            template=template,
             source_prefix=source_prefix,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),