diff --git a/README.md b/README.md
index 88b00ec8..7e5359b6 100644
--- a/README.md
+++ b/README.md
@@ -14,6 +14,8 @@
 [23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset.
 
+[23/07/29] We released two instruction-tuned 13B models on Hugging Face. See these Hugging Face repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft)) for details.
+
 [23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--template llama2` argument when you are using the LLaMA-2-chat model.
 
 [23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
 
@@ -32,7 +34,7 @@
 [23/06/15] Now we support training the **Baichuan-7B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-7B` and `--lora_target W_pack` arguments to use the Baichuan-7B model.
 
-[23/06/03] Now we support quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)
+[23/06/03] Now we support quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models.
 
 [23/05/31] Now we support training the **BLOOM & BLOOMZ** models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` and `--lora_target query_key_value` arguments to use the BLOOMZ model.
 
@@ -312,8 +314,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --predict_with_generate
 ```
 
-If you want to predict the samples with empty responses, please kindly fill the `response` column with **dummy tokens** to ensure the sample will not be discarded throughout the preprocessing phase.
-
 ### API Demo
 
 ```bash
@@ -373,7 +373,7 @@ Please follow the model licenses to use the corresponding model weights:
 - [LLaMA-2](https://ai.meta.com/llama/license/)
 - [BLOOM](https://huggingface.co/spaces/bigscience/license)
 - [Falcon](LICENSE)
-- [baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
+- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
 - [InternLM](https://github.com/InternLM/InternLM#open-source-license)
 
 ## Citation
diff --git a/README_zh.md b/README_zh.md
index 6b4773eb..2e44197a 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -14,6 +14,8 @@
 [23/07/31] 现在我们支持了训练数据流式加载。请尝试使用 `--streaming` 和 `--max_steps 100` 参数来流式加载数据集。
 
+[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft))。
+
 [23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。请注意使用 LLaMA-2-chat 模型需要添加 `--template llama2` 参数。
 
 [23/07/18] 我们开发了支持训练和测试的浏览器一键微调界面。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
 
@@ -312,8 +314,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --predict_with_generate
 ```
 
-如果需要预测的样本没有标签,请首先在 `response` 列中填入一些占位符,以免样本在预处理阶段被丢弃。
-
 ### API 服务
 
 ```bash
@@ -373,7 +373,7 @@ python src/export_model.py \
 - [LLaMA-2](https://ai.meta.com/llama/license/)
 - [BLOOM](https://huggingface.co/spaces/bigscience/license)
 - [Falcon](LICENSE)
-- [baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
+- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
 - [InternLM](https://github.com/InternLM/InternLM#open-source-license)
 
 ## 引用
diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py
index 146ad353..7ce46808 100644
--- a/src/llmtuner/__init__.py
+++ b/src/llmtuner/__init__.py
@@ -1,4 +1,4 @@
 from llmtuner.chat import ChatModel
 
 
-__version__ = "0.1.3"
+__version__ = "0.1.4"
diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 22b657e3..3997c7ef 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -93,9 +93,11 @@ def load_model_and_tokenizer(
         )
         is_mergeable = False
-        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
+    if model_args.quantization_bit is not None or os.environ.get("LOCAL_RANK") is not None:
+        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
+
     if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
        model_to_load = model_args.checkpoint_dir[0]
     else:
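The loader.py hunk above changes when `device_map` is forced: previously it was set only while quantizing, now it is also set whenever `LOCAL_RANK` is present, i.e. under a distributed launcher such as torchrun. A minimal sketch of the new rule (the helper name is mine, not part of llmtuner):

```python
import os

def resolve_device_map(quantization_bit=None):
    """Sketch of the device-map rule now used in load_model_and_tokenizer."""
    config_kwargs = {}
    if quantization_bit is not None or os.environ.get("LOCAL_RANK") is not None:
        # {"": rank} maps the root module (the empty key) to a single device,
        # so each process keeps the whole model on its own GPU.
        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
    return config_kwargs

print(resolve_device_map(quantization_bit=4))  # {'device_map': {'': 0}}
print(resolve_device_map())  # {} in a plain single-process, full-precision run
```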
diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py
index 1dc60039..f9f38058 100644
--- a/src/llmtuner/tuner/core/parser.py
+++ b/src/llmtuner/tuner/core/parser.py
@@ -32,9 +32,9 @@ def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None)
 
 def parse_train_args(
     args: Optional[Dict[str, Any]] = None
-) -> Tuple[GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
+) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
     parser = HfArgumentParser((
-        GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments
+        ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments
     ))
     return _parse_args(parser, args)
 
@@ -51,7 +51,7 @@ def parse_infer_args(
 def get_train_args(
     args: Optional[Dict[str, Any]] = None
 ) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
-    general_args, model_args, data_args, training_args, finetuning_args = parse_train_args(args)
+    model_args, data_args, training_args, finetuning_args, general_args = parse_train_args(args)
 
     # Setup logging
     if training_args.should_log:
@@ -79,6 +79,12 @@
     assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
         "Quantization is only compatible with the LoRA method."
 
+    assert not (training_args.max_steps == -1 and data_args.streaming), \
+        "Please specify `max_steps` in streaming mode."
+
+    assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
+        "Streaming mode does not support evaluation currently."
+
     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
             assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
@@ -108,12 +114,6 @@
         logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
         data_args.dev_ratio = 0
 
-    assert not (training_args.max_steps == -1 and data_args.streaming), \
-        "Please specify `max_steps` in streaming mode."
-
-    assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
-        "Streaming mode does not support evaluation currently."
-
     training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
 
     if model_args.quantization_bit is not None:
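The argument-class reorder above is not cosmetic: `HfArgumentParser` hands back one dataclass instance per registered class, in registration order, so the tuple unpacking in `get_train_args` must change in lockstep, which the second hunk does. A toy illustration with stand-in dataclasses (not the project's real argument classes):

```python
from dataclasses import dataclass
from transformers import HfArgumentParser

@dataclass
class ModelArgs:
    model_name_or_path: str = "meta-llama/Llama-2-7b-hf"

@dataclass
class DataArgs:
    dataset: str = "alpaca_en"

# Instances come back in the order the classes were registered, so changing
# the registration order silently changes what each unpacked name receives.
parser = HfArgumentParser((ModelArgs, DataArgs))
model_args, data_args = parser.parse_args_into_dataclasses(args=[])
print(type(model_args).__name__, type(data_args).__name__)  # ModelArgs DataArgs
```

Hoisting the two streaming assertions earlier in `get_train_args` also makes invalid configurations fail before any checkpoint handling runs; a streamed dataset is an iterable without a known length, so the trainer cannot infer steps per epoch and `max_steps` has to be given explicitly.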
diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py
index 0fa6a3d8..6fcfc652 100644
--- a/src/llmtuner/webui/components/chatbot.py
+++ b/src/llmtuner/webui/components/chatbot.py
@@ -1,16 +1,17 @@
-from typing import Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Optional, Tuple
 
 import gradio as gr
-from gradio.blocks import Block
-from gradio.components import Component
 
-from llmtuner.webui.chat import WebChatModel
+if TYPE_CHECKING:
+    from gradio.blocks import Block
+    from gradio.components import Component
+    from llmtuner.webui.chat import WebChatModel
 
 
 def create_chat_box(
-    chat_model: WebChatModel,
+    chat_model: "WebChatModel",
     visible: Optional[bool] = False
-) -> Tuple[Block, Component, Component, Dict[str, Component]]:
+) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
     with gr.Box(visible=visible) as chat_box:
         chatbot = gr.Chatbot()
diff --git a/src/llmtuner/webui/components/data.py b/src/llmtuner/webui/components/data.py
index 4445f39c..9787b36a 100644
--- a/src/llmtuner/webui/components/data.py
+++ b/src/llmtuner/webui/components/data.py
@@ -1,10 +1,12 @@
 import gradio as gr
-from gradio.blocks import Block
-from gradio.components import Component
-from typing import Tuple
+from typing import TYPE_CHECKING, Tuple
+
+if TYPE_CHECKING:
+    from gradio.blocks import Block
+    from gradio.components import Component
 
 
-def create_preview_box() -> Tuple[Block, Component, Component, Component]:
+def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]:
     with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
         with gr.Row():
             preview_count = gr.Number(interactive=False)
diff --git a/src/llmtuner/webui/components/eval.py b/src/llmtuner/webui/components/eval.py
index 4765ef66..29b590ae 100644
--- a/src/llmtuner/webui/components/eval.py
+++ b/src/llmtuner/webui/components/eval.py
@@ -1,14 +1,16 @@
-from typing import Dict
+from typing import TYPE_CHECKING, Dict
 
 import gradio as gr
-from gradio.components import Component
 
 from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
 from llmtuner.webui.components.data import create_preview_box
-from llmtuner.webui.runner import Runner
 from llmtuner.webui.utils import can_preview, get_preview
 
+if TYPE_CHECKING:
+    from gradio.components import Component
+    from llmtuner.webui.runner import Runner
 
-def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
+
+def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
     with gr.Row():
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
         dataset = gr.Dropdown(multiselect=True, scale=4)
diff --git a/src/llmtuner/webui/components/export.py b/src/llmtuner/webui/components/export.py
index 72b66e71..6e27ff16 100644
--- a/src/llmtuner/webui/components/export.py
+++ b/src/llmtuner/webui/components/export.py
@@ -1,11 +1,13 @@
-from typing import Dict
+from typing import TYPE_CHECKING, Dict
 
 import gradio as gr
-from gradio.components import Component
 
 from llmtuner.webui.utils import export_model
 
+if TYPE_CHECKING:
+    from gradio.components import Component
 
-def create_export_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
+
+def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
     with gr.Row():
         save_dir = gr.Textbox()
         max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
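Every webui file in this patch, above and below, gets the same PEP 484 treatment: imports used only in annotations move under `typing.TYPE_CHECKING` (which is `False` at runtime), and the annotations become string forward references. A self-contained sketch of why the code still runs:

```python
from typing import TYPE_CHECKING, Dict

if TYPE_CHECKING:
    # Seen only by static type checkers; never executed at runtime, so it
    # cannot participate in a circular import and adds no import cost.
    from gradio.components import Component

def count_elems(elems: Dict[str, "Component"]) -> int:
    # "Component" stays an unresolved string at runtime; nothing here needs
    # the real class, so the function works without the import.
    return len(elems)

print(count_elems({}))  # 0
```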
diff --git a/src/llmtuner/webui/components/infer.py b/src/llmtuner/webui/components/infer.py
index 6824c36b..40e0323e 100644
--- a/src/llmtuner/webui/components/infer.py
+++ b/src/llmtuner/webui/components/infer.py
@@ -1,13 +1,15 @@
-from typing import Dict
+from typing import TYPE_CHECKING, Dict
 
 import gradio as gr
-from gradio.components import Component
 
 from llmtuner.webui.chat import WebChatModel
 from llmtuner.webui.components.chatbot import create_chat_box
 
+if TYPE_CHECKING:
+    from gradio.components import Component
 
-def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
+
+def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
     with gr.Row():
         load_btn = gr.Button()
         unload_btn = gr.Button()
diff --git a/src/llmtuner/webui/components/sft.py b/src/llmtuner/webui/components/sft.py
index aa2b7a1a..f25d3b2a 100644
--- a/src/llmtuner/webui/components/sft.py
+++ b/src/llmtuner/webui/components/sft.py
@@ -1,16 +1,18 @@
-from typing import Dict
+from typing import TYPE_CHECKING, Dict
 
 from transformers.trainer_utils import SchedulerType
 
 import gradio as gr
-from gradio.components import Component
 
 from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
 from llmtuner.webui.components.data import create_preview_box
-from llmtuner.webui.runner import Runner
 from llmtuner.webui.utils import can_preview, get_preview, gen_plot
 
+if TYPE_CHECKING:
+    from gradio.components import Component
+    from llmtuner.webui.runner import Runner
 
-def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
+
+def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
     with gr.Row():
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
         dataset = gr.Dropdown(multiselect=True, scale=4)
diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py
index f57b99e1..4fc5b506 100644
--- a/src/llmtuner/webui/components/top.py
+++ b/src/llmtuner/webui/components/top.py
@@ -1,15 +1,17 @@
-from typing import Dict
+from typing import TYPE_CHECKING, Dict
 
 import gradio as gr
-from gradio.components import Component
 
 from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
 from llmtuner.extras.template import templates
 from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
 from llmtuner.webui.utils import can_quantize
 
+if TYPE_CHECKING:
+    from gradio.components import Component
 
-def create_top() -> Dict[str, Component]:
+
+def create_top() -> Dict[str, "Component"]:
     available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
 
     with gr.Row():
diff --git a/src/llmtuner/webui/manager.py b/src/llmtuner/webui/manager.py
index 28c40cad..d523c03f 100644
--- a/src/llmtuner/webui/manager.py
+++ b/src/llmtuner/webui/manager.py
@@ -1,15 +1,17 @@
 import gradio as gr
-from typing import Any, Dict, List
-from gradio.components import Component
+from typing import TYPE_CHECKING, Any, Dict, List
 
 from llmtuner.webui.common import get_model_path, list_dataset, load_config
 from llmtuner.webui.locales import LOCALES
 from llmtuner.webui.utils import get_time
 
+if TYPE_CHECKING:
+    from gradio.components import Component
+
 
 class Manager:
 
-    def __init__(self, elem_list: List[Dict[str, Component]]):
+    def __init__(self, elem_list: List[Dict[str, "Component"]]):
         self.elem_list = elem_list
 
     def gen_refresh(self) -> Dict[str, Any]:
@@ -24,7 +26,7 @@ class Manager:
 
         return refresh_dict
 
-    def gen_label(self, lang: str) -> Dict[Component, dict]:
+    def gen_label(self, lang: str) -> Dict["Component", dict]:
         update_dict = {}
         refresh_dict = self.gen_refresh()