diff --git a/README.md b/README.md index 94458c80..852ae132 100644 --- a/README.md +++ b/README.md @@ -12,15 +12,15 @@ ## Changelog -[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--prompt_template llama2` argument when you are using the LLaMA-2-chat model. +[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--template llama2` argument when you are using the LLaMA-2-chat model. [23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development. -[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--prompt_template baichuan` argument when you are using the Baichuan-13B-Chat model. +[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--template baichuan` argument when you are using the Baichuan-13B-Chat model. [23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested. -[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--prompt_template intern` argument when you are using the InternLM-chat model. +[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--template intern` argument when you are using the InternLM-chat model. [23/07/05] Now we support training the **Falcon-7B/40B** models in this repo. Try `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments to use the Falcon model. 
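The `--template` flag introduced here replaces `--prompt_template` and selects one of the prompt templates registered in `src/llmtuner/extras/template.py`. As a rough sketch (assuming the `llmtuner` package from this patch is importable; the query and history strings are placeholders), a template can be previewed directly with the helpers changed later in this diff:

```python
# Preview how a registered template renders a multi-turn prompt.
# `get_template` and `Template.get_prompt` follow the signatures in this diff.
from llmtuner.extras.template import get_template

template = get_template("llama2")  # e.g. for LLaMA-2-chat models
prompt = template.get_prompt(
    query="How do I merge a LoRA checkpoint?",
    history=[("Hello!", "Hi! How can I help you today?")],
    prefix="",             # empty -> fall back to the template's built-in prefix
    eos_token="</s>",      # joined between turns by the new get_prompt
)
print(prompt)
```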
@@ -153,6 +153,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --model_name_or_path path_to_your_model \ --do_train \ --dataset wiki_demo \ + --template default \ --finetuning_type lora \ --output_dir path_to_pt_checkpoint \ --overwrite_cache \ @@ -175,6 +176,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --model_name_or_path path_to_your_model \ --do_train \ --dataset alpaca_gpt4_en \ + --template default \ --finetuning_type lora \ --output_dir path_to_sft_checkpoint \ --overwrite_cache \ @@ -197,6 +199,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --model_name_or_path path_to_your_model \ --do_train \ --dataset comparison_gpt4_en \ + --template default \ --finetuning_type lora \ --resume_lora_training False \ --checkpoint_dir path_to_sft_checkpoint \ @@ -220,6 +223,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --model_name_or_path path_to_your_model \ --do_train \ --dataset alpaca_gpt4_en \ + --template default \ --finetuning_type lora \ --resume_lora_training False \ --checkpoint_dir path_to_sft_checkpoint \ @@ -278,6 +282,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --model_name_or_path path_to_your_model \ --do_eval \ --dataset alpaca_gpt4_en \ + --template default \ --finetuning_type lora \ --checkpoint_dir path_to_checkpoint \ --output_dir path_to_eval_result \ @@ -296,6 +301,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --model_name_or_path path_to_your_model \ --do_predict \ --dataset alpaca_gpt4_en \ + --template default \ --finetuning_type lora \ --checkpoint_dir path_to_checkpoint \ --output_dir path_to_predict_result \ @@ -311,6 +317,7 @@ If you want to predict the samples with empty responses, please kindly fill the ```bash python src/api_demo.py \ --model_name_or_path path_to_your_model \ + --template default \ --finetuning_type lora \ --checkpoint_dir path_to_checkpoint ``` @@ -322,6 +329,7 @@ Visit `http://localhost:8000/docs` for API documentation. 
```bash python src/cli_demo.py \ --model_name_or_path path_to_your_model \ + --template default \ --finetuning_type lora \ --checkpoint_dir path_to_checkpoint ``` @@ -331,6 +339,7 @@ python src/cli_demo.py \ ```bash python src/web_demo.py \ --model_name_or_path path_to_your_model \ + --template default \ --finetuning_type lora \ --checkpoint_dir path_to_checkpoint ``` @@ -340,6 +349,7 @@ python src/web_demo.py \ ```bash python src/export_model.py \ --model_name_or_path path_to_your_model \ + --template default \ --finetuning_type lora \ --checkpoint_dir path_to_checkpoint \ --output_dir path_to_export diff --git a/README_zh.md b/README_zh.md index 4d882b62..4ca9194d 100644 --- a/README_zh.md +++ b/README_zh.md @@ -12,15 +12,15 @@ ## 更新日志 -[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。请注意使用 LLaMA-2-chat 模型需要添加 `--prompt_template llama2` 参数。 +[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。请注意使用 LLaMA-2-chat 模型需要添加 `--template llama2` 参数。 [23/07/18] 我们开发了支持训练和测试的浏览器一键微调界面。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。 -[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path path_to_baichuan_model` 和 `--lora_target W_pack` 参数。请注意使用 Baichuan-13B-Chat 模型需要添加 `--prompt_template baichuan` 参数。 +[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path path_to_baichuan_model` 和 `--lora_target W_pack` 参数。请注意使用 Baichuan-13B-Chat 模型需要添加 `--template baichuan` 参数。 [23/07/09] 我们开源了 [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。 -[23/07/07] 现在我们支持了 **InternLM-7B** 模型的训练。请尝试使用 `--model_name_or_path internlm/internlm-7b` 参数。请注意使用 InternLM-chat 模型需要添加 `--prompt_template intern` 参数。 +[23/07/07] 现在我们支持了 **InternLM-7B** 模型的训练。请尝试使用 `--model_name_or_path internlm/internlm-7b` 参数。请注意使用 InternLM-chat 模型需要添加 `--template intern` 参数。 [23/07/05] 现在我们支持了 **Falcon-7B/40B** 模型的训练。请尝试使用 `--model_name_or_path tiiuae/falcon-7b` 和 `--lora_target query_key_value` 参数。 @@ -153,6 +153,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --model_name_or_path path_to_your_model \ --do_train \ --dataset wiki_demo \ + --template default \ --finetuning_type lora \ --output_dir path_to_pt_checkpoint \ --overwrite_cache \ @@ -174,7 +175,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --stage sft \ --model_name_or_path path_to_your_model \ --do_train \ - --dataset alpaca_gpt4_en \ + --dataset alpaca_gpt4_zh \ + --template default \ --finetuning_type lora \ --output_dir path_to_sft_checkpoint \ --overwrite_cache \ @@ -196,7 +198,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --stage rm \ --model_name_or_path path_to_your_model \ --do_train \ - --dataset comparison_gpt4_en \ + --dataset comparison_gpt4_zh \ + --template default \ --finetuning_type lora \ --resume_lora_training False \ --checkpoint_dir path_to_sft_checkpoint \ @@ -219,7 +222,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --stage ppo \ --model_name_or_path path_to_your_model \ --do_train \ - --dataset alpaca_gpt4_en \ + --dataset alpaca_gpt4_zh \ + --template default \ --finetuning_type lora \ --resume_lora_training False \ --checkpoint_dir path_to_sft_checkpoint \ @@ -277,7 +281,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --stage sft \ --model_name_or_path path_to_your_model \ --do_eval \ - --dataset alpaca_gpt4_en \ + --dataset 
alpaca_gpt4_zh \ + --template default \ --finetuning_type lora \ --checkpoint_dir path_to_checkpoint \ --output_dir path_to_eval_result \ @@ -295,7 +300,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --stage sft \ --model_name_or_path path_to_your_model \ --do_predict \ - --dataset alpaca_gpt4_en \ + --dataset alpaca_gpt4_zh \ + --template default \ --finetuning_type lora \ --checkpoint_dir path_to_checkpoint \ --output_dir path_to_predict_result \ @@ -311,6 +317,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ ```bash python src/api_demo.py \ --model_name_or_path path_to_your_model \ + --template default \ --finetuning_type lora \ --checkpoint_dir path_to_checkpoint ``` @@ -322,6 +329,7 @@ python src/api_demo.py \ ```bash python src/cli_demo.py \ --model_name_or_path path_to_your_model \ + --template default \ --finetuning_type lora \ --checkpoint_dir path_to_checkpoint ``` @@ -331,6 +339,7 @@ python src/cli_demo.py \ ```bash python src/web_demo.py \ --model_name_or_path path_to_your_model \ + --template default \ --finetuning_type lora \ --checkpoint_dir path_to_checkpoint ``` @@ -340,6 +349,7 @@ python src/web_demo.py \ ```bash python src/export_model.py \ --model_name_or_path path_to_your_model \ + --template default \ --finetuning_type lora \ --checkpoint_dir path_to_checkpoint \ --output_dir path_to_export diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py index a4b46dd6..773dfc4e 100644 --- a/src/llmtuner/chat/stream_chat.py +++ b/src/llmtuner/chat/stream_chat.py @@ -1,42 +1,50 @@ import torch -from typing import Any, Dict, Generator, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple from threading import Thread from transformers import TextIteratorStreamer from llmtuner.extras.misc import get_logits_processor from llmtuner.extras.template import get_template -from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments from llmtuner.tuner import load_model_and_tokenizer +if TYPE_CHECKING: + from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments + class ChatModel: def __init__( self, - model_args: ModelArguments, - data_args: DataArguments, - finetuning_args: FinetuningArguments, - generating_args: GeneratingArguments + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments" ) -> None: self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args) if torch.cuda.device_count() > 1: - from accelerate import dispatch_model, infer_auto_device_map - device_map = infer_auto_device_map(self.model) + from accelerate import dispatch_model + from accelerate.utils import infer_auto_device_map, get_balanced_memory + device_map = infer_auto_device_map(self.model, max_memory=get_balanced_memory(self.model)) self.model = dispatch_model(self.model, device_map) else: self.model = self.model.cuda() - self.template = get_template(data_args.prompt_template) - self.source_prefix = data_args.source_prefix or "" + self.template = get_template(data_args.template) + self.source_prefix = data_args.source_prefix self.generating_args = generating_args def process_args( - self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs + self, + query: str, + history: Optional[List[Tuple[str, str]]] = None, + prefix: Optional[str] = None, + **input_kwargs ) -> Tuple[Dict[str, Any], 
int]: prefix = prefix or self.source_prefix - inputs = self.tokenizer([self.template.get_prompt(query, history, prefix)], return_tensors="pt") + prompt = self.template.get_prompt(query, history, prefix, self.tokenizer.eos_token) + inputs = self.tokenizer([prompt], return_tensors="pt") inputs = inputs.to(self.model.device) prompt_length = len(inputs["input_ids"][0]) @@ -71,7 +79,11 @@ class ChatModel: @torch.inference_mode() def chat( - self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs + self, + query: str, + history: Optional[List[Tuple[str, str]]] = None, + prefix: Optional[str] = None, + **input_kwargs ) -> Tuple[str, Tuple[int, int]]: gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs) generation_output = self.model.generate(**gen_kwargs) @@ -82,7 +94,11 @@ class ChatModel: @torch.inference_mode() def stream_chat( - self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs + self, + query: str, + history: Optional[List[Tuple[str, str]]] = None, + prefix: Optional[str] = None, + **input_kwargs ) -> Generator[str, None, None]: gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs) streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) diff --git a/src/llmtuner/dsets/loader.py b/src/llmtuner/dsets/loader.py index 005cbee5..8d0bfe59 100644 --- a/src/llmtuner/dsets/loader.py +++ b/src/llmtuner/dsets/loader.py @@ -1,40 +1,50 @@ import os import hashlib -from typing import List +from typing import TYPE_CHECKING, List, Optional -from datasets import Dataset, concatenate_datasets, load_dataset +from datasets import concatenate_datasets, interleave_datasets, load_dataset from llmtuner.extras.logging import get_logger -from llmtuner.hparams import ModelArguments, DataArguments + +if TYPE_CHECKING: + from datasets import Dataset + from llmtuner.hparams import ModelArguments, DataArguments logger = get_logger(__name__) +EXT2TYPE = { + "csv": "csv", + "json": "json", + "jsonl": "json", + "txt": "text" +} + + +def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None: + if file_sha1 is None: + logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.") + return + + if len(data_files) != 1: + logger.warning("Checksum failed: too many files.") + return + + with open(data_files[0], "rb") as f: + sha1 = hashlib.sha1(f.read()).hexdigest() + if sha1 != file_sha1: + logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0])) + + def get_dataset( - model_args: ModelArguments, - data_args: DataArguments -) -> Dataset: - - def checksum(file_path, hash): - with open(file_path, "rb") as datafile: - binary_data = datafile.read() - sha1 = hashlib.sha1(binary_data).hexdigest() - if sha1 != hash: - logger.warning("Checksum failed for {}. 
It may vary depending on the platform.".format(file_path)) - - ext2type = { - "csv": "csv", - "json": "json", - "jsonl": "json", - "txt": "text" - } - + model_args: "ModelArguments", + data_args: "DataArguments" +) -> "Dataset": max_samples = data_args.max_samples - all_datasets: List[Dataset] = [] # support multiple datasets + all_datasets: List["Dataset"] = [] # support multiple datasets for dataset_attr in data_args.dataset_list: - logger.info("Loading dataset {}...".format(dataset_attr)) if dataset_attr.load_from == "hf_hub": @@ -47,60 +57,56 @@ def get_dataset( data_path = None data_files: List[str] = [] - if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): + if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # directory for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name)) - if data_path is None: - data_path = ext2type.get(data_files[0].split(".")[-1], None) + data_path = EXT2TYPE.get(file_name.split(".")[-1], None) else: - assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match." - elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): + assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file type does not match." + elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # single file data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)) - data_path = ext2type.get(data_files[0].split(".")[-1], None) + data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None) else: raise ValueError("File not found.") assert data_path, "File extension must be txt, csv, json or jsonl." 
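For local files, the loader now maps the file extension to a `datasets` builder through `EXT2TYPE` and forwards `split` and `streaming` to `load_dataset`. A standalone sketch of that lookup (the file name is a placeholder):

```python
# Standalone illustration of the extension -> builder lookup used above.
from datasets import load_dataset

EXT2TYPE = {"csv": "csv", "json": "json", "jsonl": "json", "txt": "text"}

file_name = "my_sft_data.jsonl"                    # placeholder local file
builder = EXT2TYPE.get(file_name.split(".")[-1])   # -> "json"
assert builder, "File extension must be txt, csv, json or jsonl."

dataset = load_dataset(builder, data_files=[file_name], split="train")
print(next(iter(dataset)))
```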
- - if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None: - checksum(data_files[0], dataset_attr.dataset_sha1) - else: - logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.") + checksum(data_files, dataset_attr.dataset_sha1) else: raise NotImplementedError - raw_datasets = load_dataset( + dataset = load_dataset( data_path, data_files=data_files, + split=data_args.split, cache_dir=model_args.cache_dir, + streaming=data_args.streaming, use_auth_token=True if model_args.use_auth_token else None ) - dataset = raw_datasets[data_args.split] if max_samples is not None: max_samples_temp = min(len(dataset), max_samples) dataset = dataset.select(range(max_samples_temp)) - dummy_data = [None] * len(dataset) - prefix_data = [dataset_attr.source_prefix] * len(dataset) - for column_name, target_name in [ - ("prompt_column", "prompt"), - ("query_column", "query"), - ("response_column", "response"), - ("history_column", "history") - ]: # every dataset will have 4 columns same as each other - if getattr(dataset_attr, column_name) != target_name: - if getattr(dataset_attr, column_name): - dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name) - else: # None or empty string - dataset = dataset.add_column(target_name, dummy_data) - dataset = dataset.add_column("prefix", prefix_data) + for column_name in ["prompt", "query", "response", "history"]: # align datasets + if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name: + dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name) + + if dataset_attr.source_prefix: # add prefix + dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}) + all_datasets.append(dataset) if len(data_args.dataset_list) == 1: - all_datasets = all_datasets[0] + return all_datasets[0] + elif data_args.mix_strategy == "concat": + if data_args.streaming: + logger.warning("The samples between different datasets will not be mixed in streaming mode.") + return concatenate_datasets(all_datasets) + elif data_args.mix_strategy.startswith("interleave"): + if not data_args.streaming: + logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.") + stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted" + return interleave_datasets(all_datasets, stopping_strategy=stopping_strategy) else: - all_datasets = concatenate_datasets(all_datasets) - - return all_datasets + raise ValueError("Unknown mixing strategy.") diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index f743e27e..93c854e0 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -1,65 +1,63 @@ -from typing import Literal +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal from itertools import chain -from transformers import Seq2SeqTrainingArguments -from transformers.tokenization_utils import PreTrainedTokenizer - -from datasets import Dataset from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.template import get_template -from llmtuner.hparams import DataArguments + +if TYPE_CHECKING: + from datasets import Dataset + from transformers import Seq2SeqTrainingArguments + from transformers.tokenization_utils import PreTrainedTokenizer + from llmtuner.hparams import DataArguments def preprocess_dataset( - dataset: Dataset, - tokenizer: PreTrainedTokenizer, - data_args: DataArguments, - training_args: 
Seq2SeqTrainingArguments, + dataset: "Dataset", + tokenizer: "PreTrainedTokenizer", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo"] -) -> Dataset: +) -> "Dataset": + column_names = list(dataset.column_names or []) + template = get_template(data_args.template) - column_names = list(dataset.column_names) - prompt_template = get_template(data_args.prompt_template) - - # support question with a single answer or multiple answers - def get_dialog(examples): + def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]: for i in range(len(examples["prompt"])): - if examples["prompt"][i] and examples["response"][i]: - query, answer = examples["prompt"][i], examples["response"][i] - query = query + "\n" + examples["query"][i] if examples["query"][i] else query - prefix = examples["prefix"][i] if examples["prefix"][i] else "" - dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix) - yield dialog + query, response = examples["prompt"][i], examples["response"][i] + query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query + history = examples["history"][i] if "history" in examples and examples["history"][i] else [] + prefix = examples["prefix"][i] if "prefix" in examples and examples["prefix"][i] else "" + yield query, response, history, prefix - def preprocess_pretrain_dataset(examples): + def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: # build grouped texts with format ` X1 X2 X3 ...` (without ) - text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"] - concatenated_ids = list(chain(*text_ids)) - total_length = len(concatenated_ids) - block_size = data_args.max_source_length - 1 + tokenized_examples = tokenizer(examples["prompt"], add_special_tokens=False) + concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} + total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) + block_size = data_args.max_source_length # we drop the small remainder, and if the total_length < block_size, we exclude this batch total_length = (total_length // block_size) * block_size # split by chunks of max_source_length - result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size] - for i in range(0, total_length, block_size)] - return { - "input_ids": result, - "labels": result.copy() + result = { + k: [t[i: i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() } + result["labels"] = result["input_ids"].copy() + return result - def preprocess_supervised_dataset(examples): + def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: # build inputs with format ` X Y ` and labels with format ` ...
Y ` # for input with history, we build multiple input-label pairs just like: # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112 - model_inputs = {"input_ids": [], "labels": []} + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} max_length = data_args.max_source_length + data_args.max_target_length - for dialog in get_dialog(examples): + for query, response, history, prefix in construct_example(examples): input_ids, labels = [], [] - for i in range(len(dialog) // 2): - source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0)) - target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False) + for i, (query_i, resp_i) in enumerate(template.get_dialog(query, response, history, prefix)): + source_ids = tokenizer.encode(text=query_i, add_special_tokens=(i == 0)) + target_ids = tokenizer.encode(text=resp_i, add_special_tokens=False) if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] @@ -73,19 +71,20 @@ def preprocess_dataset( labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id] model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["labels"].append(labels) return model_inputs - def preprocess_unsupervised_dataset(examples): + def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: # build inputs with format ` X` and labels with format ` Y` - model_inputs = {"input_ids": [], "labels": []} + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} - for dialog in get_dialog(examples): - prompt, answer = "".join(dialog[:-1]), dialog[-1] + for query, response, history, prefix in construct_example(examples): + prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token) source_ids = tokenizer.encode(text=prompt, add_special_tokens=True) - target_ids = tokenizer.encode(text=answer, add_special_tokens=True) + target_ids = tokenizer.encode(text=response, add_special_tokens=True) if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] @@ -93,6 +92,7 @@ def preprocess_dataset( target_ids = target_ids[:data_args.max_target_length] model_inputs["input_ids"].append(source_ids) + model_inputs["attention_mask"].append([1] * len(source_ids)) model_inputs["labels"].append(target_ids) return model_inputs @@ -100,12 +100,12 @@ def preprocess_dataset( def preprocess_pairwise_dataset(examples): # build input pairs with format ` X Y1 ` and ` X Y2 ` model_inputs = {"accept_ids": [], "reject_ids": []} - for dialog in get_dialog(examples): - prompt, answer = "".join(dialog[:-1]), dialog[-1] + for query, response, history, prefix in construct_example(examples): + prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token) source_ids = tokenizer.encode(text=prompt, add_special_tokens=True) - accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False) - reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False) + accept_ids = tokenizer.encode(text=response[0], add_special_tokens=False) + reject_ids = tokenizer.encode(text=response[1], add_special_tokens=False) if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] @@ -141,34 +141,44 @@ def preprocess_dataset( print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) if stage == "pt": + dataset 
= dataset.filter(lambda example: example["prompt"]) preprocess_function = preprocess_pretrain_dataset - elif stage == "sft": - if not training_args.predict_with_generate: - preprocess_function = preprocess_supervised_dataset - else: - preprocess_function = preprocess_unsupervised_dataset + elif stage == "sft" and not training_args.predict_with_generate: + dataset = dataset.filter(lambda example: example["prompt"] and example["response"]) + preprocess_function = preprocess_supervised_dataset elif stage == "rm": + dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1) preprocess_function = preprocess_pairwise_dataset - elif stage == "ppo": + else: + dataset = dataset.filter(lambda example: example["prompt"]) preprocess_function = preprocess_unsupervised_dataset with training_args.main_process_first(desc="dataset map pre-processing"): + kwargs = {} + if not data_args.streaming: + kwargs = dict( + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on dataset" + ) + dataset = dataset.map( preprocess_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, + batched=True, remove_columns=column_names, - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on dataset" + **kwargs ) + if data_args.streaming: + dataset = dataset.shuffle(buffer_size=data_args.buffer_size) + if stage == "pt": - print_unsupervised_dataset_example(dataset[0]) + print_unsupervised_dataset_example(next(iter(dataset))) elif stage == "sft": - print_supervised_dataset_example(dataset[0]) + print_supervised_dataset_example(next(iter(dataset))) elif stage == "rm": - print_pairwise_dataset_example(dataset[0]) + print_pairwise_dataset_example(next(iter(dataset))) elif stage == "ppo": - print_unsupervised_dataset_example(dataset[0]) + print_unsupervised_dataset_example(next(iter(dataset))) return dataset diff --git a/src/llmtuner/dsets/utils.py b/src/llmtuner/dsets/utils.py index 64436e70..31c48222 100644 --- a/src/llmtuner/dsets/utils.py +++ b/src/llmtuner/dsets/utils.py @@ -1,13 +1,12 @@ -from typing import Dict -from datasets import Dataset +from typing import TYPE_CHECKING, Dict + +if TYPE_CHECKING: + from datasets import Dataset -def split_dataset( - dataset: Dataset, dev_ratio: float, do_train: bool -) -> Dict[str, Dataset]: - # Split the dataset +def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]: if do_train: - if dev_ratio > 1e-6: + if dev_ratio > 1e-6: # Split the dataset dataset = dataset.train_test_split(test_size=dev_ratio) return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]} else: diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py index 112178ba..9c45b31e 100644 --- a/src/llmtuner/extras/callbacks.py +++ b/src/llmtuner/extras/callbacks.py @@ -1,16 +1,13 @@ import os import json import time +from typing import TYPE_CHECKING from datetime import timedelta -from transformers import ( - TrainerCallback, - TrainerControl, - TrainerState, - TrainingArguments -) -from transformers.trainer_callback import TrainerControl, TrainerState -from transformers.training_args import TrainingArguments +from transformers import TrainerCallback + +if TYPE_CHECKING: + from transformers import TrainingArguments, TrainerState, TrainerControl class LogCallback(TrainerCallback): @@ -20,13 +17,13 @@ class LogCallback(TrainerCallback): self.start_time = time.time() self.tracker = {} - def 
on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): r""" Event called at the beginning of training. """ self.start_time = time.time() - def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + def on_step_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): r""" Event called at the beginning of a training step. If using gradient accumulation, one training step might take several inputs. @@ -35,7 +32,7 @@ class LogCallback(TrainerCallback): control.should_epoch_stop = True control.should_training_stop = True - def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): r""" Event called at the end of an substep during gradient accumulation. """ @@ -43,7 +40,7 @@ class LogCallback(TrainerCallback): control.should_epoch_stop = True control.should_training_stop = True - def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None: + def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None: r""" Event called after logging the last logs. """ diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index 9c4e165e..82e695dc 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -1,12 +1,14 @@ import torch -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple -from transformers.modeling_utils import PreTrainedModel from transformers.generation.utils import LogitsProcessorList from transformers.generation.logits_process import LogitsProcessor from llmtuner.extras.constants import LAYERNORM_NAMES +if TYPE_CHECKING: + from transformers.modeling_utils import PreTrainedModel + class AverageMeter: r""" @@ -44,29 +46,37 @@ def get_logits_processor() -> LogitsProcessorList: return logits_processor -def print_trainable_params(model: torch.nn.Module) -> None: +def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: + r""" + Returns the number of trainable parameters and number of all parameters in the model. 
+ """ trainable_params, all_param = 0, 0 for param in model.parameters(): num_params = param.numel() # if using DS Zero 3 and the weights are initialized empty if num_params == 0 and hasattr(param, "ds_numel"): num_params = param.ds_numel + + # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2 + if param.__class__.__name__ == "Params4bit": + num_params = num_params * 2 + all_param += num_params if param.requires_grad: trainable_params += num_params - print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( - trainable_params, all_param, 100 * trainable_params / all_param)) + + return trainable_params, all_param # Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32 # Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35 def prepare_model_for_training( - model: PreTrainedModel, + model: "PreTrainedModel", finetuning_type: str, output_layer_name: Optional[str] = "lm_head", use_gradient_checkpointing: Optional[bool] = True, layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES -) -> PreTrainedModel: +) -> "PreTrainedModel": for name, param in model.named_parameters(): if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): @@ -84,6 +94,9 @@ def prepare_model_for_training( model.config.use_cache = False # turn off when gradient checkpointing is enabled if finetuning_type != "full" and hasattr(model, output_layer_name): + if hasattr(model, "config") and hasattr(model.config, "pretraining_tp"): + model.config.pretraining_tp = 1 # disable TP for LoRA (https://github.com/huggingface/peft/pull/728) + output_layer: torch.nn.Linear = getattr(model, output_layer_name) input_dtype = output_layer.weight.dtype @@ -92,11 +105,8 @@ def prepare_model_for_training( def forward(self, x: torch.Tensor) -> torch.Tensor: return super().forward(x.to(input_dtype)).to(torch.float32) - new_output_layer = CastOutputToFloat(output_layer) - # adapt to LLaMA-2's pretraining_tp (actually LLaMA models can automatically do casting but BLOOM models cannot) - # (https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/llama/modeling_llama.py#L819) - setattr(new_output_layer, "weight", output_layer.weight) - setattr(model, output_layer_name, new_output_layer) + setattr(model, output_layer_name, CastOutputToFloat(output_layer)) + return model diff --git a/src/llmtuner/extras/save_and_load.py b/src/llmtuner/extras/save_and_load.py index 781b9bb7..32dc651c 100644 --- a/src/llmtuner/extras/save_and_load.py +++ b/src/llmtuner/extras/save_and_load.py @@ -1,6 +1,6 @@ import os import torch -from typing import Dict, Optional +from typing import Dict from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME from transformers.modeling_utils import load_sharded_checkpoint @@ -12,12 +12,12 @@ from llmtuner.extras.logging import get_logger logger = get_logger(__name__) -def get_state_dict(model: torch.nn.Module, trainable_only: Optional[bool] = True) -> Dict[str, torch.Tensor]: - state_dict = model.state_dict() +def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: + state_dict: Dict[str, torch.Tensor] = model.state_dict() filtered_state_dict = {} for k, v in model.named_parameters(): - if (not trainable_only) or v.requires_grad: + if v.requires_grad: filtered_state_dict[k] = state_dict[k].cpu().clone().detach() return filtered_state_dict diff --git 
a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index c0fda323..8cdc3511 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -11,37 +11,46 @@ class Template: use_history: bool def get_prompt( - self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = "" + self, + query: str, + history: Optional[List[Tuple[str, str]]] = None, + prefix: Optional[str] = "", + eos_token: Optional[str] = "" ) -> str: r""" Returns a string containing prompt without response. """ - return "".join(self._format_example(query, history, prefix)) + return eos_token.join(map(lambda x: x[0] + x[1], self._format_example(query, history, prefix))) def get_dialog( - self, query: str, resp: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = "" - ) -> List[str]: + self, + query: str, + resp: str, + history: Optional[List[Tuple[str, str]]] = None, + prefix: Optional[str] = "" + ) -> List[Tuple[str, str]]: r""" - Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response. + Returns a list containing prompt-response pairs. """ - return self._format_example(query, history, prefix) + [resp] + result = self._format_example(query, history, prefix) + result[-1][-1] = resp + return result def _format_example( - self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = "" - ) -> List[str]: + self, + query: str, + history: Optional[List[Tuple[str, str]]] = None, + prefix: Optional[str] = "" + ) -> List[Tuple[str, str]]: prefix = prefix or self.prefix # use prefix if provided prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix history = history if (history and self.use_history) else [] - history = history + [(query, "")] - convs = [] - for turn_idx, (user_query, bot_resp) in enumerate(history): - if turn_idx == 0: - convs.append(prefix + self.prompt.format(query=user_query)) - convs.append(bot_resp) - else: - convs.append(self.sep + self.prompt.format(query=user_query)) - convs.append(bot_resp) - return convs[:-1] # drop last + history = history + [(query, "")] + convs = [ + [(self.sep if turn_idx else prefix) + self.prompt.format(query=query_i), resp_i] + for turn_idx, (query_i, resp_i) in enumerate(history) + ] + return convs templates: Dict[str, Template] = {} @@ -103,7 +112,7 @@ register_template( "explain why instead of answering something not correct. " "If you don't know the answer to a question, please don't share false information.\n<>\n\n", prompt=" [INST] {query} [/INST] ", - sep="", + sep="", use_history=True ) @@ -131,7 +140,7 @@ register_template( prefix="A chat between a curious user and an artificial intelligence assistant. 
" "The assistant gives helpful, detailed, and polite answers to the user's questions.", prompt="USER: {query} ASSISTANT: ", - sep="", + sep="", use_history=True ) @@ -216,7 +225,7 @@ register_template( name="baichuan", prefix="", prompt="{query}", - sep="", + sep="", use_history=True ) diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index 3ef9ebcb..ce88d4d9 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -1,6 +1,6 @@ import os import json -from typing import List, Optional +from typing import List, Literal, Optional from dataclasses import dataclass, field @@ -16,10 +16,10 @@ class DatasetAttr: return self.dataset_name def __post_init__(self): - self.prompt_column = "instruction" - self.query_column = "input" - self.response_column = "output" - self.history_column = None + self.prompt = "instruction" + self.query = "input" + self.response = "output" + self.history = None @dataclass @@ -27,8 +27,11 @@ class DataArguments: """ Arguments pertaining to what data we are going to input our model for training and evaluation. """ + template: str = field( + metadata={"help": "Which template to use for constructing prompts in training and inference."} + ) dataset: Optional[str] = field( - default="alpaca_zh", + default="alpaca_en", metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."} ) dataset_dir: Optional[str] = field( @@ -39,6 +42,18 @@ class DataArguments: default="train", metadata={"help": "Which dataset split to use for training and evaluation."} ) + streaming: Optional[bool] = field( + default=False, + metadata={"help": "Enable streaming mode."} + ) + buffer_size: Optional[int] = field( + default=16384, + metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."} + ) + mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field( + default="concat", + metadata={"help": "Strategy to use in dataset mixing."} + ) overwrite_cache: Optional[bool] = field( default=False, metadata={"help": "Overwrite the cached training and evaluation sets."} @@ -75,10 +90,6 @@ class DataArguments: default=0, metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."} ) - prompt_template: Optional[str] = field( - default="default", - metadata={"help": "Which template to use for constructing prompts in training and inference."} - ) def init_for_training(self): # support mixing multiple datasets dataset_names = [ds.strip() for ds in self.dataset.split(",")] @@ -111,9 +122,9 @@ class DataArguments: dataset_attr.source_prefix = prefix_list[i] if "columns" in dataset_info[name]: - dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None) - dataset_attr.query_column = dataset_info[name]["columns"].get("query", None) - dataset_attr.response_column = dataset_info[name]["columns"].get("response", None) - dataset_attr.history_column = dataset_info[name]["columns"].get("history", None) + dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None) + dataset_attr.query = dataset_info[name]["columns"].get("query", None) + dataset_attr.response = dataset_info[name]["columns"].get("response", None) + dataset_attr.history = dataset_info[name]["columns"].get("history", None) self.dataset_list.append(dataset_attr) diff --git a/src/llmtuner/hparams/general_args.py b/src/llmtuner/hparams/general_args.py index a97a4935..397d3019 100644 --- 
a/src/llmtuner/hparams/general_args.py +++ b/src/llmtuner/hparams/general_args.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field @dataclass class GeneralArguments: """ - Arguments pertaining to which techniques we are going to fine-tuning with. + Arguments pertaining to which stage we are going to perform. """ stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field( default="sft", diff --git a/src/llmtuner/tuner/core/adapter.py b/src/llmtuner/tuner/core/adapter.py index 5fddeb99..4afad13a 100644 --- a/src/llmtuner/tuner/core/adapter.py +++ b/src/llmtuner/tuner/core/adapter.py @@ -1,7 +1,7 @@ import os import torch +from typing import TYPE_CHECKING -from transformers.modeling_utils import PreTrainedModel from peft import ( PeftModel, TaskType, @@ -12,19 +12,22 @@ from peft.utils import CONFIG_NAME, WEIGHTS_NAME from llmtuner.extras.logging import get_logger from llmtuner.extras.save_and_load import load_trainable_params -from llmtuner.hparams import ModelArguments, FinetuningArguments + +if TYPE_CHECKING: + from transformers.modeling_utils import PreTrainedModel + from llmtuner.hparams import ModelArguments, FinetuningArguments logger = get_logger(__name__) def init_adapter( - model: PreTrainedModel, - model_args: ModelArguments, - finetuning_args: FinetuningArguments, + model: "PreTrainedModel", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", is_trainable: bool, is_mergeable: bool -) -> PreTrainedModel: +) -> "PreTrainedModel": r""" Initializes the adapters. diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py index 31509b72..921dbc11 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/tuner/core/loader.py @@ -1,6 +1,6 @@ import os import torch -from typing import Literal, Optional, Tuple +from typing import TYPE_CHECKING, Literal, Optional, Tuple from transformers import ( AutoConfig, @@ -16,11 +16,13 @@ from transformers.tokenization_utils import PreTrainedTokenizerBase from trl import AutoModelForCausalLMWithValueHead from llmtuner.extras.logging import get_logger -from llmtuner.extras.misc import prepare_model_for_training, print_trainable_params +from llmtuner.extras.misc import count_parameters, prepare_model_for_training from llmtuner.extras.save_and_load import load_valuehead_params -from llmtuner.hparams import ModelArguments, FinetuningArguments from llmtuner.tuner.core.adapter import init_adapter +if TYPE_CHECKING: + from llmtuner.hparams import ModelArguments, FinetuningArguments + logger = get_logger(__name__) @@ -33,8 +35,8 @@ require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7") def load_model_and_tokenizer( - model_args: ModelArguments, - finetuning_args: FinetuningArguments, + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", is_trainable: Optional[bool] = False, stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft" ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]: @@ -141,6 +143,9 @@ def load_model_and_tokenizer( model.requires_grad_(False) # fix all model params model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16 - print_trainable_params(model) + trainable_params, all_param = count_parameters(model) + logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( + trainable_params, all_param, 100 * trainable_params / all_param + )) return model, tokenizer diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py index 31e738f3..38c13f76 100644 --- 
a/src/llmtuner/tuner/core/parser.py +++ b/src/llmtuner/tuner/core/parser.py @@ -19,20 +19,39 @@ from llmtuner.hparams import ( logger = get_logger(__name__) +def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None): + if args is not None: + return parser.parse_dict(args) + elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): + return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) + elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + return parser.parse_json_file(os.path.abspath(sys.argv[1])) + else: + return parser.parse_args_into_dataclasses() + + +def parse_train_args( + args: Optional[Dict[str, Any]] = None +) -> Tuple[GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments]: + parser = HfArgumentParser(( + GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments + )) + return _parse_args(parser, args) + + +def parse_infer_args( + args: Optional[Dict[str, Any]] = None +) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]: + parser = HfArgumentParser(( + ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments + )) + return _parse_args(parser, args) + + def get_train_args( args: Optional[Dict[str, Any]] = None ) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]: - - parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments)) - - if args is not None: - model_args, data_args, training_args, finetuning_args, general_args = parser.parse_dict(args) - elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): - model_args, data_args, training_args, finetuning_args, general_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1])) - elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - model_args, data_args, training_args, finetuning_args, general_args = parser.parse_json_file(os.path.abspath(sys.argv[1])) - else: - model_args, data_args, training_args, finetuning_args, general_args = parser.parse_args_into_dataclasses() + general_args, model_args, data_args, training_args, finetuning_args = parse_train_args(args) # Setup logging if training_args.should_log: @@ -73,13 +92,22 @@ def get_train_args( if training_args.do_train and (not training_args.fp16): logger.warning("We recommend enable fp16 mixed precision training.") - if data_args.prompt_template == "default": - logger.warning("Please specify `prompt_template` if you are using other pre-trained models.") - - if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None: - logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.") + if ( + training_args.local_rank != -1 + and training_args.ddp_find_unused_parameters is None + and finetuning_args.finetuning_type == "lora" + ): + logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.") training_args.ddp_find_unused_parameters = False + if data_args.max_samples is not None and data_args.streaming: + logger.warning("`max_samples` is incompatible with `streaming`. Disabling streaming mode.") + data_args.streaming = False + + if data_args.dev_ratio > 1e-6 and data_args.streaming: + logger.warning("`dev_ratio` is incompatible with `streaming`. 
Disabling development set.") + data_args.dev_ratio = 0 + training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning if model_args.quantization_bit is not None: @@ -106,17 +134,7 @@ def get_train_args( def get_infer_args( args: Optional[Dict[str, Any]] = None ) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]: - - parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments)) - - if args is not None: - model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args) - elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): - model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1])) - elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1])) - else: - model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses() + model_args, data_args, finetuning_args, generating_args = parse_infer_args(args) assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \ "Quantization is only compatible with the LoRA method." @@ -128,7 +146,4 @@ def get_infer_args( assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \ "Quantized model only accepts a single checkpoint." - if data_args.prompt_template == "default": - logger.warning("Please specify `prompt_template` if you are using other pre-trained models.") - return model_args, data_args, finetuning_args, generating_args diff --git a/src/llmtuner/tuner/core/trainer.py b/src/llmtuner/tuner/core/trainer.py index c9bb7043..928b0d9b 100644 --- a/src/llmtuner/tuner/core/trainer.py +++ b/src/llmtuner/tuner/core/trainer.py @@ -1,16 +1,19 @@ import os import torch -from typing import Dict, Optional +from typing import TYPE_CHECKING, Dict, Optional from transformers import Seq2SeqTrainer -from transformers.trainer import TRAINING_ARGS_NAME +from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME from transformers.modeling_utils import PreTrainedModel, unwrap_model from peft import PeftModel +from trl import PreTrainedModelWrapper from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME from llmtuner.extras.logging import get_logger -from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params, load_valuehead_params -from llmtuner.hparams import FinetuningArguments +from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params + +if TYPE_CHECKING: + from llmtuner.hparams import FinetuningArguments logger = get_logger(__name__) @@ -21,7 +24,7 @@ class PeftTrainer(Seq2SeqTrainer): Inherits Seq2SeqTrainer to support parameter-efficient checkpoints. 
""" - def __init__(self, finetuning_args: FinetuningArguments, **kwargs): + def __init__(self, finetuning_args: "FinetuningArguments", **kwargs): super().__init__(**kwargs) self.finetuning_args = finetuning_args self._remove_log() @@ -42,31 +45,35 @@ class PeftTrainer(Seq2SeqTrainer): output_dir = output_dir if output_dir is not None else self.args.output_dir os.makedirs(output_dir, exist_ok=True) logger.info(f"Saving model checkpoint to {output_dir}") + model = unwrap_model(self.model) + state_dict = state_dict or get_state_dict(model) - if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only) - backbone_model = getattr(model, "pretrained_model") - torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME)) - else: - backbone_model = model + if isinstance(model, PreTrainedModelWrapper): + model_params, v_head_params = {}, {} + for name in state_dict.keys(): + if name.startswith("pretrained_model."): + model_params[name.replace("pretrained_model.", "")] = state_dict[name] + elif name.startswith("v_head."): + v_head_params[name.replace("v_head.", "")] = state_dict[name] - if isinstance(backbone_model, PeftModel): # LoRA tuning - backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model)) - elif isinstance(backbone_model, PreTrainedModel): # freeze/full tuning - backbone_model.config.use_cache = True - backbone_model.save_pretrained( - output_dir, - state_dict=get_state_dict(backbone_model, trainable_only=(self.finetuning_args.finetuning_type != "full")), - safe_serialization=self.args.save_safetensors - ) - backbone_model.config.use_cache = False - if self.tokenizer is not None: - self.tokenizer.save_pretrained(output_dir) + torch.save(v_head_params, os.path.join(output_dir, VALUE_HEAD_FILE_NAME)) + state_dict = model_params + model = model.pretrained_model + + if isinstance(model, (PeftModel, PreTrainedModel)): + model.config.use_cache = True + model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors) + model.config.use_cache = False else: - logger.warning("No model to save.") + torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + + if self.tokenizer is not None: + self.tokenizer.save_pretrained(output_dir) with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f: f.write(self.args.to_json_string() + "\n") + self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME)) def _load_best_model(self): @@ -76,16 +83,15 @@ class PeftTrainer(Seq2SeqTrainer): Subclass and override to inject custom behavior. It should not be directly used by external scripts. 
""" logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") - model = unwrap_model(self.model) - backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model - if isinstance(backbone_model, PeftModel): - backbone_model.load_adapter(self.state.best_model_checkpoint, backbone_model.active_adapter) - if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint): - model.v_head.load_state_dict({ - "summary.weight": getattr(model, "reward_head_weight"), - "summary.bias": getattr(model, "reward_head_bias") - }) + if isinstance(model, PreTrainedModelWrapper): + model.v_head.load_state_dict(torch.load( + os.path.join(self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME), map_location="cpu" + )) + model = model.pretrained_model + + if isinstance(model, PeftModel): + model.load_adapter(self.state.best_model_checkpoint, model.active_adapter) else: # freeze/full-tuning - load_trainable_params(backbone_model, self.state.best_model_checkpoint) + load_trainable_params(model, self.state.best_model_checkpoint) diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index 668d2f44..f28cb93f 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -2,21 +2,25 @@ import os import math import torch from tqdm import tqdm -from typing import Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Callable, Dict, List, Optional -from transformers import Seq2SeqTrainingArguments, TrainerState, TrainerControl +from transformers import TrainerState, TrainerControl from transformers.modeling_utils import PreTrainedModel from trl import PPOTrainer from trl.core import LengthSampler -from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.logging import get_logger from llmtuner.extras.misc import AverageMeter, get_logits_processor -from llmtuner.hparams import FinetuningArguments + from llmtuner.tuner.core.trainer import PeftTrainer from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments + from llmtuner.extras.callbacks import LogCallback + from llmtuner.hparams import FinetuningArguments + logger = get_logger(__name__) @@ -27,9 +31,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): """ def __init__( self, - training_args: Seq2SeqTrainingArguments, - finetuning_args: FinetuningArguments, - callbacks: List[LogCallback], + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + callbacks: List["LogCallback"], **kwargs ): PPOTrainer.__init__(self, **kwargs) diff --git a/src/llmtuner/tuner/ppo/utils.py b/src/llmtuner/tuner/ppo/utils.py index 55f67be1..984dcb08 100644 --- a/src/llmtuner/tuner/ppo/utils.py +++ b/src/llmtuner/tuner/ppo/utils.py @@ -1,11 +1,13 @@ import torch -from typing import Dict, List, Literal, Optional, Tuple -from trl import AutoModelForCausalLMWithValueHead +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple from llmtuner.extras.constants import LAYERNORM_NAMES +if TYPE_CHECKING: + from trl import AutoModelForCausalLMWithValueHead -def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None: + +def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None: if target == "reward": # save default head temporarily valuehead_state_dict = model.v_head.state_dict() 
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"]) @@ -19,10 +21,10 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def def cast_layernorm_dtype( - model: AutoModelForCausalLMWithValueHead, + model: "AutoModelForCausalLMWithValueHead", layer_norm_names: List[str] = LAYERNORM_NAMES, layer_norm_params: Optional[Dict[str, torch.Tensor]] = None -) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]: +) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]: layer_norm_state_dict = {} diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py index 1257fd76..3a229c8c 100644 --- a/src/llmtuner/tuner/ppo/workflow.py +++ b/src/llmtuner/tuner/ppo/workflow.py @@ -2,26 +2,30 @@ # https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py import math +from typing import TYPE_CHECKING from trl import PPOConfig from torch.optim import AdamW from typing import Optional, List -from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, TrainerCallback +from transformers import DataCollatorForSeq2Seq from transformers.optimization import get_scheduler from llmtuner.dsets import get_dataset, preprocess_dataset from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.ploting import plot_loss -from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments from llmtuner.tuner.core import load_model_and_tokenizer from llmtuner.tuner.ppo.trainer import PPOPeftTrainer +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments + def run_ppo( - model_args: ModelArguments, - data_args: DataArguments, - training_args: Seq2SeqTrainingArguments, - finetuning_args: FinetuningArguments, - callbacks: Optional[List[TrainerCallback]] = [LogCallback()] + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + callbacks: Optional[List["TrainerCallback"]] = [LogCallback()] ): dataset = get_dataset(model_args, data_args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo") diff --git a/src/llmtuner/tuner/pt/workflow.py b/src/llmtuner/tuner/pt/workflow.py index 59813532..1dbb6852 100644 --- a/src/llmtuner/tuner/pt/workflow.py +++ b/src/llmtuner/tuner/pt/workflow.py @@ -1,24 +1,27 @@ # Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py import math -from typing import Optional, List -from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback +from typing import TYPE_CHECKING, Optional, List +from transformers import DataCollatorForSeq2Seq from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.ploting import plot_loss -from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments from llmtuner.tuner.core import load_model_and_tokenizer from llmtuner.tuner.core.trainer import PeftTrainer +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments + def run_pt( - model_args: ModelArguments, - 
data_args: DataArguments, - training_args: Seq2SeqTrainingArguments, - finetuning_args: FinetuningArguments, - callbacks: Optional[List[TrainerCallback]] = [LogCallback()] + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + callbacks: Optional[List["TrainerCallback"]] = [LogCallback()] ): dataset = get_dataset(model_args, data_args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt") diff --git a/src/llmtuner/tuner/rm/collator.py b/src/llmtuner/tuner/rm/collator.py index 57d6b54b..c0da0579 100644 --- a/src/llmtuner/tuner/rm/collator.py +++ b/src/llmtuner/tuner/rm/collator.py @@ -15,5 +15,8 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding): We generate 2 * n examples where the first n examples represent chosen examples and the last n examples represent rejected examples. """ - features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features] + features = [ + {"input_ids": feature[key], "attention_mask": [1] * len(feature[key])} + for key in ("accept_ids", "reject_ids") for feature in features + ] return super().__call__(features) diff --git a/src/llmtuner/tuner/rm/trainer.py b/src/llmtuner/tuner/rm/trainer.py index 584183c4..e69d48a8 100644 --- a/src/llmtuner/tuner/rm/trainer.py +++ b/src/llmtuner/tuner/rm/trainer.py @@ -1,13 +1,15 @@ import os import json import torch -from typing import Dict, List, Optional, Tuple, Union -from transformers.trainer import PredictionOutput -from transformers.modeling_utils import PreTrainedModel +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from llmtuner.extras.logging import get_logger from llmtuner.tuner.core.trainer import PeftTrainer +if TYPE_CHECKING: + from transformers.trainer import PredictionOutput + from transformers.modeling_utils import PreTrainedModel + logger = get_logger(__name__) @@ -23,7 +25,7 @@ class PairwisePeftTrainer(PeftTrainer): def compute_loss( self, - model: PreTrainedModel, + model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: Optional[bool] = False ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: @@ -46,7 +48,7 @@ class PairwisePeftTrainer(PeftTrainer): def save_predictions( self, - predict_results: PredictionOutput + predict_results: "PredictionOutput" ) -> None: r""" Saves model predictions to `output_dir`. 
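Note on the reward-model hunks above: `PairwiseDataCollatorWithPadding` now emits explicit attention masks and keeps its "2 * n" layout, with the first n padded sequences coming from the chosen responses and the last n from the rejected ones, so anything consuming the batch must split it in that order. The sketch below is illustrative only and is not the repository's `compute_loss`; the `pairwise_logsigmoid_loss` helper and the example reward tensor are hypothetical, and the Bradley-Terry style objective is an assumption about how such pairs are typically scored.

```python
import torch
import torch.nn.functional as F

def pairwise_logsigmoid_loss(rewards: torch.Tensor) -> torch.Tensor:
    # `rewards` has shape (2 * n,): the first n scores belong to the chosen
    # responses, the last n to the rejected ones, matching the ordering
    # produced by PairwiseDataCollatorWithPadding above.
    n = rewards.size(0) // 2
    r_chosen, r_rejected = rewards[:n], rewards[n:]
    # Push the chosen reward above the rejected reward for every pair.
    return -F.logsigmoid(r_chosen - r_rejected).mean()

# Toy usage with three preference pairs.
loss = pairwise_logsigmoid_loss(torch.tensor([1.2, 0.3, 0.8, 0.4, 0.9, -0.1]))
print(loss.item())
```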
diff --git a/src/llmtuner/tuner/rm/workflow.py b/src/llmtuner/tuner/rm/workflow.py
index b7022c15..ec2b8ada 100644
--- a/src/llmtuner/tuner/rm/workflow.py
+++ b/src/llmtuner/tuner/rm/workflow.py
@@ -2,25 +2,27 @@
 # https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
 # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
 
-from typing import Optional, List
-from transformers import Seq2SeqTrainingArguments, TrainerCallback
+from typing import TYPE_CHECKING, Optional, List
 
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.rm.metric import compute_accuracy
 from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
 from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
 
+if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+
 
 def run_rm(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: Seq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
diff --git a/src/llmtuner/tuner/sft/metric.py b/src/llmtuner/tuner/sft/metric.py
index 8e67cc79..663b037d 100644
--- a/src/llmtuner/tuner/sft/metric.py
+++ b/src/llmtuner/tuner/sft/metric.py
@@ -1,7 +1,6 @@
 import numpy as np
 from dataclasses import dataclass
-from typing import Dict, Sequence, Tuple, Union
-from transformers.tokenization_utils import PreTrainedTokenizer
+from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
 
 import jieba
 from rouge_chinese import Rouge
@@ -9,6 +8,9 @@ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
 
 from llmtuner.extras.constants import IGNORE_INDEX
 
+if TYPE_CHECKING:
+    from transformers.tokenization_utils import PreTrainedTokenizer
+
 
 @dataclass
 class ComputeMetrics:
@@ -16,7 +18,7 @@ class ComputeMetrics:
     Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
     """
 
-    tokenizer: PreTrainedTokenizer
+    tokenizer: "PreTrainedTokenizer"
 
     def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
         r"""
diff --git a/src/llmtuner/tuner/sft/trainer.py b/src/llmtuner/tuner/sft/trainer.py
index 851767f8..8755bf59 100644
--- a/src/llmtuner/tuner/sft/trainer.py
+++ b/src/llmtuner/tuner/sft/trainer.py
@@ -3,13 +3,15 @@ import json
 import torch
 import numpy as np
 import torch.nn as nn
-from typing import Any, Dict, List, Optional, Tuple, Union
-from transformers.trainer import PredictionOutput
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.logging import get_logger
 from llmtuner.tuner.core.trainer import PeftTrainer
 
+if TYPE_CHECKING:
+    from transformers.trainer import PredictionOutput
+
 
 logger = get_logger(__name__)
@@ -81,7 +83,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
 
     def save_predictions(
         self,
-        predict_results: PredictionOutput
+        predict_results: "PredictionOutput"
     ) -> None:
         r"""
         Saves model predictions to `output_dir`.
diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py
index 6ba2f621..9a8feeb0 100644
--- a/src/llmtuner/tuner/sft/workflow.py
+++ b/src/llmtuner/tuner/sft/workflow.py
@@ -1,25 +1,28 @@
 # Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
 
-from typing import Optional, List
-from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
+from typing import TYPE_CHECKING, Optional, List
+from transformers import DataCollatorForSeq2Seq
 
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.sft.metric import ComputeMetrics
 from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
 
+if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+
 
 def run_sft(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: Seq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
diff --git a/src/llmtuner/webui/chat.py b/src/llmtuner/webui/chat.py
index fcbeaa79..d3cdbecc 100644
--- a/src/llmtuner/webui/chat.py
+++ b/src/llmtuner/webui/chat.py
@@ -54,7 +54,7 @@ class WebChatModel(ChatModel):
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
             quantization_bit=int(quantization_bit) if quantization_bit else None,
-            prompt_template=template,
+            template=template,
             source_prefix=source_prefix
         )
         super().__init__(*get_infer_args(args))
diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py
index 99aa6da5..08f72d0a 100644
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -111,7 +111,7 @@ class Runner:
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
             quantization_bit=int(quantization_bit) if quantization_bit else None,
-            prompt_template=template,
+            template=template,
             source_prefix=source_prefix,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),
@@ -201,7 +201,7 @@ class Runner:
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
             quantization_bit=int(quantization_bit) if quantization_bit else None,
-            prompt_template=template,
+            template=template,
             source_prefix=source_prefix,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),
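Taken together, most of the hunks above follow one pattern: imports that are only needed for annotations (`transformers` classes, `trl`, the `llmtuner.hparams` dataclasses) move under an `if TYPE_CHECKING:` guard, and the annotations that used them become quoted forward references, so importing a tuner module no longer pulls those dependencies in at runtime. The webui hunks are separate and simply rename the `prompt_template=` keyword to `template=` when assembling arguments. Below is a minimal, self-contained sketch of the TYPE_CHECKING pattern; the `run_example` function is hypothetical and exists only to illustrate the idiom.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by static type checkers only; skipped entirely at runtime.
    from transformers import Seq2SeqTrainingArguments


def run_example(training_args: "Seq2SeqTrainingArguments") -> None:
    # The quoted annotation is a forward reference, so this module can be
    # imported without loading transformers; the real object is only needed
    # when a caller actually invokes the function.
    print(training_args.output_dir)
```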