From 7b0629dac45d472e5d96953a4772fda4aacdd94b Mon Sep 17 00:00:00 2001 From: zhouwei <363232733@qq.com> Date: Mon, 6 May 2024 13:29:59 +0800 Subject: [PATCH 01/13] The training efficiency of the Ascend 910A has been significantly enhanced, leveraging the full computational power of the NPU (Neural Processing Unit) and the capabilities of torch_npu, a PyTorch library optimized for NPUs. This improvement has resulted in a remarkable tenfold increase in efficiency. Former-commit-id: 28ae947161d4670d4f865cbaad84397d47215a53 --- src/train.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/train.py b/src/train.py index 6a3212cb..e2609b66 100644 --- a/src/train.py +++ b/src/train.py @@ -1,3 +1,7 @@ +import os +import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu from llmtuner.train.tuner import run_exp @@ -11,4 +15,6 @@ def _mp_fn(index): if __name__ == "__main__": + use_jit_compile = os.getenv('JIT_COMPILE', 'False').lower() in ['true', '1'] + torch.npu.set_compile_mode(jit_compile=use_jit_compile) main() From 5a5d450648e8781083cd3b4ab96dad06806acecb Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 14 May 2024 20:37:21 +0800 Subject: [PATCH 02/13] fix #3728 Former-commit-id: cfaee8b4cf5f89d767a20a057d2335bd30ec83a2 --- src/llmtuner/extras/ploting.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/llmtuner/extras/ploting.py b/src/llmtuner/extras/ploting.py index e53f1f89..dea23bbe 100644 --- a/src/llmtuner/extras/ploting.py +++ b/src/llmtuner/extras/ploting.py @@ -21,6 +21,9 @@ def smooth(scalars: List[float]) -> List[float]: r""" EMA implementation according to TensorBoard. """ + if len(scalars) == 0: + return [] + last = scalars[0] smoothed = [] weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function @@ -32,6 +35,9 @@ def smooth(scalars: List[float]) -> List[float]: def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure": + r""" + Plots loss curves in LlamaBoard. + """ plt.close("all") plt.switch_backend("agg") fig = plt.figure() @@ -51,6 +57,9 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None: + r""" + Plots loss curves and saves the image. 
+ """ plt.switch_backend("agg") with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f: data = json.load(f) From 332f44fa43e46df112ca62035905f3a3bfaf77d8 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Tue, 14 May 2024 20:44:04 +0800 Subject: [PATCH 03/13] Apply suggestions from code review Co-authored-by: Huazhong Ji Former-commit-id: 0ac6e73f9971a9310026ddc609b5266cb1639b64 --- src/train.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/train.py b/src/train.py index e2609b66..098ec1b5 100644 --- a/src/train.py +++ b/src/train.py @@ -1,6 +1,4 @@ import os -import torch -import torch_npu from torch_npu.contrib import transfer_to_npu from llmtuner.train.tuner import run_exp @@ -15,6 +13,7 @@ def _mp_fn(index): if __name__ == "__main__": - use_jit_compile = os.getenv('JIT_COMPILE', 'False').lower() in ['true', '1'] - torch.npu.set_compile_mode(jit_compile=use_jit_compile) + if is_torch_npu_available(): + use_jit_compile = os.getenv('JIT_COMPILE', 'False').lower() in ['true', '1'] + torch.npu.set_compile_mode(jit_compile=use_jit_compile) main() From fe586de3446ffbcc00b28ed654f069e3132780d5 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Tue, 14 May 2024 20:44:21 +0800 Subject: [PATCH 04/13] Apply suggestions from code review Co-authored-by: Huazhong Ji Former-commit-id: 9089bc70c8838cb80473e557a750855f7b7a7695 --- src/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/train.py b/src/train.py index 098ec1b5..00a7fa26 100644 --- a/src/train.py +++ b/src/train.py @@ -1,5 +1,5 @@ import os -from torch_npu.contrib import transfer_to_npu +from transformers import is_torch_npu_available from llmtuner.train.tuner import run_exp From 082506eba8b63b98d30b584ddfbeef5c260103d7 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Tue, 14 May 2024 20:47:52 +0800 Subject: [PATCH 05/13] Update train.py Former-commit-id: 1c3c4989022025db756965350ae0381fc9db32e5 --- src/train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/train.py b/src/train.py index 00a7fa26..4cc21194 100644 --- a/src/train.py +++ b/src/train.py @@ -1,5 +1,8 @@ import os + +import torch from transformers import is_torch_npu_available + from llmtuner.train.tuner import run_exp From ec9ed23cfd5c8e37a0da09a0d39679e382a4d4a3 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 14 May 2024 21:36:42 +0800 Subject: [PATCH 06/13] use robust envs Former-commit-id: c187b20aaa0a0eb7300d537fd9006bf977a02854 --- src/llmtuner/api/app.py | 2 +- src/llmtuner/extras/callbacks.py | 2 +- src/llmtuner/extras/misc.py | 2 +- src/llmtuner/webui/interface.py | 4 ++-- src/webui.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py index 6d06d1d0..21edab2f 100644 --- a/src/llmtuner/api/app.py +++ b/src/llmtuner/api/app.py @@ -51,7 +51,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": allow_methods=["*"], allow_headers=["*"], ) - api_key = os.environ.get("API_KEY", None) + api_key = os.environ.get("API_KEY") security = HTTPBearer(auto_error=False) async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]): diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py index 6d24b244..637b786d 100644 --- a/src/llmtuner/extras/callbacks.py +++ b/src/llmtuner/extras/callbacks.py @@ -53,7 +53,7 @@ class LogCallback(TrainerCallback): self.aborted = False self.do_train = False """ Web UI """ - self.webui_mode = 
bool(int(os.environ.get("LLAMABOARD_ENABLED", "0"))) + self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"] if self.webui_mode: signal.signal(signal.SIGABRT, self._set_abort) self.logger_handler = LoggerHandler(output_dir) diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index 8ce25d18..53140efa 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -58,7 +58,7 @@ class AverageMeter: def check_dependencies() -> None: - if int(os.environ.get("DISABLE_VERSION_CHECK", "0")): + if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: logger.warning("Version checking has been disabled, may lead to unexpected behaviors.") else: require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2") diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py index 91709d40..c5a30113 100644 --- a/src/llmtuner/webui/interface.py +++ b/src/llmtuner/webui/interface.py @@ -71,12 +71,12 @@ def create_web_demo() -> gr.Blocks: def run_web_ui() -> None: - gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0"))) + gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"] server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0") create_ui().queue().launch(share=gradio_share, server_name=server_name) def run_web_demo() -> None: - gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0"))) + gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"] server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0") create_web_demo().queue().launch(share=gradio_share, server_name=server_name) diff --git a/src/webui.py b/src/webui.py index 3f8690d0..7a43039d 100644 --- a/src/webui.py +++ b/src/webui.py @@ -4,7 +4,7 @@ from llmtuner.webui.interface import create_ui def main(): - gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0"))) + gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"] server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0") create_ui().queue().launch(share=gradio_share, server_name=server_name) From f5df1ceaf1343870dcb5d5b094c18ce6b343721e Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 14 May 2024 23:32:53 +0800 Subject: [PATCH 07/13] add npu examples Former-commit-id: af343034dd31303be59678af9d1eae338864e884 --- examples/deepspeed/ds_z0_config.json | 18 ++++++++ examples/full_multi_gpu/multi_node.sh | 2 +- examples/full_multi_gpu/single_node.sh | 12 ++++-- examples/lora_multi_gpu/ds_zero3.sh | 12 ++++-- examples/lora_multi_npu/ds_zero0.sh | 15 +++++++ .../lora_multi_npu/llama3_lora_sft_ds.yaml | 42 +++++++++++++++++++ src/llmtuner/model/patcher.py | 9 +++- src/llmtuner/model/utils/attention.py | 4 +- src/train.py | 8 ---- 9 files changed, 103 insertions(+), 19 deletions(-) create mode 100644 examples/deepspeed/ds_z0_config.json create mode 100644 examples/lora_multi_npu/ds_zero0.sh create mode 100644 examples/lora_multi_npu/llama3_lora_sft_ds.yaml diff --git a/examples/deepspeed/ds_z0_config.json b/examples/deepspeed/ds_z0_config.json new file mode 100644 index 00000000..b7826b20 --- /dev/null +++ b/examples/deepspeed/ds_z0_config.json @@ -0,0 +1,18 @@ +{ + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "zero_allow_untested_optimizer": true, + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, 
+ "bf16": { + "enabled": "auto" + } +} \ No newline at end of file diff --git a/examples/full_multi_gpu/multi_node.sh b/examples/full_multi_gpu/multi_node.sh index 962409a1..34c038d4 100644 --- a/examples/full_multi_gpu/multi_node.sh +++ b/examples/full_multi_gpu/multi_node.sh @@ -6,7 +6,7 @@ RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 -CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.run \ +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \ --nproc_per_node $NPROC_PER_NODE \ --nnodes $NNODES \ --node_rank $RANK \ diff --git a/examples/full_multi_gpu/single_node.sh b/examples/full_multi_gpu/single_node.sh index 97f7af64..ac29c097 100644 --- a/examples/full_multi_gpu/single_node.sh +++ b/examples/full_multi_gpu/single_node.sh @@ -1,9 +1,15 @@ #!/bin/bash NPROC_PER_NODE=4 +NNODES=1 +RANK=0 +MASTER_ADDR=127.0.0.1 +MASTER_PORT=29500 -CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.run \ +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \ --nproc_per_node $NPROC_PER_NODE \ - --nnodes 1 \ - --standalone \ + --nnodes $NNODES \ + --node_rank $RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT \ src/train.py examples/full_multi_gpu/llama3_full_sft.yaml diff --git a/examples/lora_multi_gpu/ds_zero3.sh b/examples/lora_multi_gpu/ds_zero3.sh index b8fd2640..90ea00dd 100644 --- a/examples/lora_multi_gpu/ds_zero3.sh +++ b/examples/lora_multi_gpu/ds_zero3.sh @@ -1,9 +1,15 @@ #!/bin/bash NPROC_PER_NODE=4 +NNODES=1 +RANK=0 +MASTER_ADDR=127.0.0.1 +MASTER_PORT=29500 -CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.run \ +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \ --nproc_per_node $NPROC_PER_NODE \ - --nnodes 1 \ - --standalone \ + --nnodes $NNODES \ + --node_rank $RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT \ src/train.py examples/lora_multi_gpu/llama3_lora_sft_ds.yaml diff --git a/examples/lora_multi_npu/ds_zero0.sh b/examples/lora_multi_npu/ds_zero0.sh new file mode 100644 index 00000000..f849c5c9 --- /dev/null +++ b/examples/lora_multi_npu/ds_zero0.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +NPROC_PER_NODE=4 +NNODES=1 +RANK=0 +MASTER_ADDR=127.0.0.1 +MASTER_PORT=29500 + +ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 torchrun \ + --nproc_per_node $NPROC_PER_NODE \ + --nnodes $NNODES \ + --node_rank $RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT \ + src/train.py examples/lora_multi_gpu/llama3_lora_sft_ds.yaml diff --git a/examples/lora_multi_npu/llama3_lora_sft_ds.yaml b/examples/lora_multi_npu/llama3_lora_sft_ds.yaml new file mode 100644 index 00000000..2e9c0558 --- /dev/null +++ b/examples/lora_multi_npu/llama3_lora_sft_ds.yaml @@ -0,0 +1,42 @@ +# model +model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct + +# method +stage: sft +do_train: true +finetuning_type: lora +lora_target: q_proj,v_proj + +# ddp +ddp_timeout: 180000000 +deepspeed: examples/deepspeed/ds_z0_config.json + +# dataset +dataset: identity,alpaca_gpt4_en +template: llama3 +cutoff_len: 1024 +max_samples: 1000 +overwrite_cache: true +preprocessing_num_workers: 16 + +# output +output_dir: saves/llama3-8b/lora/sft +logging_steps: 10 +save_steps: 500 +plot_loss: true +overwrite_output_dir: true + +# train +per_device_train_batch_size: 1 +gradient_accumulation_steps: 2 +learning_rate: 0.0001 +num_train_epochs: 3.0 +lr_scheduler_type: cosine +warmup_steps: 0.1 +fp16: true + +# eval +val_size: 0.1 +per_device_eval_batch_size: 1 +evaluation_strategy: steps +eval_steps: 500 diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index fd99bd3b..b28a23d0 100644 --- 
a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -1,9 +1,10 @@ +import os from types import MethodType from typing import TYPE_CHECKING, Any, Dict import torch from peft import PeftModel -from transformers import PreTrainedModel, PreTrainedTokenizerBase +from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available from transformers.integrations import is_deepspeed_zero3_enabled from ..extras.logging import get_logger @@ -44,6 +45,10 @@ def patch_config( if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) + if is_torch_npu_available(): + use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"] + torch.npu.set_compile_mode(jit_compile=use_jit_compile) + configure_attn_implementation(config, model_args) configure_rope(config, model_args, is_trainable) configure_longlora(config, model_args, is_trainable) @@ -56,7 +61,7 @@ def patch_config( logger.info("Using KV cache for faster generation.") if getattr(config, "model_type", None) == "qwen": - setattr(config, "use_flash_attn", model_args.flash_attn) + setattr(config, "use_flash_attn", model_args.flash_attn == "fa2") for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]: setattr(config, dtype_name, model_args.compute_dtype == dtype) diff --git a/src/llmtuner/model/utils/attention.py b/src/llmtuner/model/utils/attention.py index f4686489..b52ddc86 100644 --- a/src/llmtuner/model/utils/attention.py +++ b/src/llmtuner/model/utils/attention.py @@ -22,7 +22,7 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model elif model_args.flash_attn == "sdpa": if not is_sdpa_available(): - logger.warning("Torch>=2.1.1 is required for SDPA attention.") + logger.warning("torch>=2.1.1 is required for SDPA attention.") return requested_attn_implementation = "sdpa" @@ -52,4 +52,4 @@ def print_attn_implementation(config: "PretrainedConfig") -> None: elif attn_implementation == "sdpa": logger.info("Using torch SDPA for faster training and inference.") else: - logger.info("Using vanilla Attention implementation.") + logger.info("Using vanilla attention implementation.") diff --git a/src/train.py b/src/train.py index 4cc21194..6a3212cb 100644 --- a/src/train.py +++ b/src/train.py @@ -1,8 +1,3 @@ -import os - -import torch -from transformers import is_torch_npu_available - from llmtuner.train.tuner import run_exp @@ -16,7 +11,4 @@ def _mp_fn(index): if __name__ == "__main__": - if is_torch_npu_available(): - use_jit_compile = os.getenv('JIT_COMPILE', 'False').lower() in ['true', '1'] - torch.npu.set_compile_mode(jit_compile=use_jit_compile) main() From 943779eabc6f14e83297494bc6d67b3f4e222c6a Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 14 May 2024 23:55:49 +0800 Subject: [PATCH 08/13] update readme Former-commit-id: fc547ee591ef3cfc1bdbb8297a75a74f05c83c82 --- README.md | 26 +++++++++++++++++++++++--- README_zh.md | 24 ++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 90fcb295..a138d646 100644 --- a/README.md +++ b/README.md @@ -70,14 +70,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog +[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details. + [24/05/13] We supported fine-tuning the **Yi-1.5** series models. 
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage. -[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details. -
<details><summary>Full Changelog</summary>

+[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.

+
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.

[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See [examples](examples/README.md) for usage.

@@ -328,7 +330,7 @@ Extra dependencies available: torch, metrics, deepspeed, bitsandbytes, vllm, gal
<details><summary>For Windows users</summary>

If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.2, please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.

```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
```

To enable FlashAttention-2 on the Windows platform, you need to install the prec

</details>

+<details><summary>For Ascend NPU users</summary>
+
+To utilize Ascend NPU devices for (distributed) training and inference, you need to install the **[torch-npu](https://gitee.com/ascend/pytorch)** package and the **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**.
+
+| Requirement | Minimum | Recommend |
+| ------------ | ------- | --------- |
+| CANN | 8.0.RC1 | 8.0.RC1 |
+| torch | 2.2.0 | 2.2.0 |
+| torch-npu | 2.2.0 | 2.2.0 |
+| deepspeed | 0.13.2 | 0.13.2 |
+
+> [!NOTE]
+> Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
+>
+> If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations.
+
+</details>
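For illustration, the `JIT_COMPILE` toggle used throughout this series (and set in `src/llmtuner/model/patcher.py` above) can be exercised on an Ascend machine with the following minimal sketch. It assumes `torch-npu` and a recent `transformers` are installed; the prints and the fallback branch are illustrative only, not repository code:

```python
import os

import torch
from transformers import is_torch_npu_available

if is_torch_npu_available():
    # Same toggle as the patcher.py change in this series: JIT compilation
    # is opt-in via the JIT_COMPILE environment variable, off by default.
    use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
    torch.npu.set_compile_mode(jit_compile=use_jit_compile)
    print("Ascend NPU detected, jit_compile =", use_jit_compile)
else:
    print("No Ascend NPU detected, nothing to configure.")
```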
+ ### Data Preparation Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope hub or load the dataset in local disk. diff --git a/README_zh.md b/README_zh.md index 1d15515e..a0373711 100644 --- a/README_zh.md +++ b/README_zh.md @@ -70,14 +70,16 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd ## 更新日志 +[24/05/14] 我们支持了昇腾 NPU 设备的训练和推理。详情请查阅[安装](#安装-llama-factory)部分。 + [24/05/13] 我们支持了 Yi-1.5 系列模型的微调。 [24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。 -[24/04/22] 我们提供了在免费 T4 GPU 上微调 Llama-3 模型的 **[Colab 笔记本](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)**。Hugging Face 社区公开了两个利用 LLaMA Factory 微调的 Llama-3 模型,详情请见 [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) 和 [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese)。 -
<details><summary>展开日志</summary>

+[24/04/22] 我们提供了在免费 T4 GPU 上微调 Llama-3 模型的 **[Colab 笔记本](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)**。Hugging Face 社区公开了两个利用 LLaMA Factory 微调的 Llama-3 模型，详情请见 [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) 和 [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese)。

+
[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 [examples](examples/README_zh.md)。

[24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)**。详细用法请参照 [examples](examples/README_zh.md)。

@@ -338,6 +340,24 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl

</details>

+<details><summary>昇腾 NPU 用户指南</summary>
+
+如果使用昇腾 NPU 设备进行(分布式)训练或推理，需要安装 **[torch-npu](https://gitee.com/ascend/pytorch)** 库和 **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**。
+
+| 依赖项 | 至少 | 推荐 |
+| ------------ | ------- | --------- |
+| CANN | 8.0.RC1 | 8.0.RC1 |
+| torch | 2.2.0 | 2.2.0 |
+| torch-npu | 2.2.0 | 2.2.0 |
+| deepspeed | 0.13.2 | 0.13.2 |
+
+> [!NOTE]
+> 请记得使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定您使用的设备。
+>
+> 如果遇到无法正常推理的情况，请尝试设置 `do_sample: false`。
+
+</details>
+ ### 数据准备 关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。你可以使用 HuggingFace / ModelScope 上的数据集或加载本地数据集。 From be1114bb43a2804bc3cc390043bf94cf2918cc30 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 14 May 2024 23:57:08 +0800 Subject: [PATCH 09/13] update readme Former-commit-id: b96d84835f9237e7277bb86395e448348473d20f --- README.md | 7 +++---- README_zh.md | 7 +++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index a138d646..826512c6 100644 --- a/README.md +++ b/README.md @@ -351,10 +351,9 @@ To utilize Ascend NPU devices for (distributed) training and inference, you need | torch-npu | 2.2.0 | 2.2.0 | | deepspeed | 0.13.2 | 0.13.2 | -> [!NOTE] -> Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use. -> -> If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations. +Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use. + +If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations.
diff --git a/README_zh.md b/README_zh.md index a0373711..d41ff13a 100644 --- a/README_zh.md +++ b/README_zh.md @@ -351,10 +351,9 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl | torch-npu | 2.2.0 | 2.2.0 | | deepspeed | 0.13.2 | 0.13.2 | -> [!NOTE] -> 请记得使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定您使用的设备。 -> -> 如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`。 +请记得使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定您使用的设备。 + +如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`。 From c4743674ab52a17e53c1acc931cf0d31b745f0c5 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 15 May 2024 00:05:17 +0800 Subject: [PATCH 10/13] update examples Former-commit-id: 5bdad463875100e402329d47cd4c14bf9bc3b84b --- examples/README.md | 9 +++++++++ examples/README_zh.md | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/examples/README.md b/examples/README.md index 0838314a..4b4a8248 100644 --- a/examples/README.md +++ b/examples/README.md @@ -7,6 +7,7 @@ Make sure to execute these commands in the `LLaMA-Factory` directory. - [LoRA Fine-Tuning on A Single GPU](#lora-fine-tuning-on-a-single-gpu) - [QLoRA Fine-Tuning on a Single GPU](#qlora-fine-tuning-on-a-single-gpu) - [LoRA Fine-Tuning on Multiple GPUs](#lora-fine-tuning-on-multiple-gpus) +- [LoRA Fine-Tuning on Multiple NPUs](#lora-fine-tuning-on-multiple-npus) - [Full-Parameter Fine-Tuning on Multiple GPUs](#full-parameter-fine-tuning-on-multiple-gpus) - [Merging LoRA Adapters and Quantization](#merging-lora-adapters-and-quantization) - [Inferring LoRA Fine-Tuned Models](#inferring-lora-fine-tuned-models) @@ -124,6 +125,14 @@ bash examples/lora_multi_gpu/multi_node.sh bash examples/lora_multi_gpu/ds_zero3.sh ``` +### LoRA Fine-Tuning on Multiple NPUs + +#### Supervised Fine-Tuning with DeepSpeed ZeRO-0 + +```bash +bash examples/lora_multi_npu/ds_zero0.sh +``` + ### Full-Parameter Fine-Tuning on Multiple GPUs #### Supervised Fine-Tuning with Accelerate on Single Node diff --git a/examples/README_zh.md b/examples/README_zh.md index 7fe43954..3b5b2dee 100644 --- a/examples/README_zh.md +++ b/examples/README_zh.md @@ -7,6 +7,7 @@ - [单 GPU LoRA 微调](#单-gpu-lora-微调) - [单 GPU QLoRA 微调](#单-gpu-qlora-微调) - [多 GPU LoRA 微调](#多-gpu-lora-微调) +- [多 NPU LoRA 微调](#多-npu-lora-微调) - [多 GPU 全参数微调](#多-gpu-全参数微调) - [合并 LoRA 适配器与模型量化](#合并-lora-适配器与模型量化) - [推理 LoRA 模型](#推理-lora-模型) @@ -124,6 +125,14 @@ bash examples/lora_multi_gpu/multi_node.sh bash examples/lora_multi_gpu/ds_zero3.sh ``` +### 多 NPU LoRA 微调 + +#### 使用 DeepSpeed ZeRO-0 训练 + +```bash +bash examples/lora_multi_npu/ds_zero0.sh +``` + ### 多 GPU 全参数微调 #### 使用 DeepSpeed 进行单节点训练 From 213ba09b24f495690c532ecb4ad3784ddb2fa845 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 15 May 2024 00:26:10 +0800 Subject: [PATCH 11/13] fix examples Former-commit-id: 7e69e71a52c736d0e42afbf61a3b3c22db606bc2 --- examples/deepspeed/ds_z0_config.json | 10 ++++++++++ examples/lora_multi_npu/ds_zero0.sh | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/deepspeed/ds_z0_config.json b/examples/deepspeed/ds_z0_config.json index b7826b20..ed326676 100644 --- a/examples/deepspeed/ds_z0_config.json +++ b/examples/deepspeed/ds_z0_config.json @@ -14,5 +14,15 @@ }, "bf16": { "enabled": "auto" + }, + "zero_optimization": { + "stage": 0, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "contiguous_gradients": true, + "round_robin_gradients": true } } \ No newline at end of file 
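The ZeRO-0 file completed above leans on `"auto"` placeholders that DeepSpeed leaves for the HuggingFace Trainer to fill in from the training arguments. A small, illustrative check of that file (only the path comes from this patch; the inspection script itself is an assumption, not repository code):

```python
import json

# Load the ZeRO-0 config added in this series and list its "auto" placeholders,
# which the HuggingFace Trainer resolves from its own arguments at launch time.
with open("examples/deepspeed/ds_z0_config.json", "r", encoding="utf-8") as f:
    ds_config = json.load(f)

assert ds_config["zero_optimization"]["stage"] == 0
auto_keys = [key for key, value in ds_config.items() if value == "auto"]
print("Resolved by the trainer at launch time:", auto_keys)
```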
diff --git a/examples/lora_multi_npu/ds_zero0.sh b/examples/lora_multi_npu/ds_zero0.sh index f849c5c9..4ffaa1b0 100644 --- a/examples/lora_multi_npu/ds_zero0.sh +++ b/examples/lora_multi_npu/ds_zero0.sh @@ -12,4 +12,4 @@ ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 torchrun \ --node_rank $RANK \ --master_addr $MASTER_ADDR \ --master_port $MASTER_PORT \ - src/train.py examples/lora_multi_gpu/llama3_lora_sft_ds.yaml + src/train.py examples/lora_multi_npu/llama3_lora_sft_ds.yaml From ef167f839daadafe815505edeb8c163ef4237bae Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 15 May 2024 01:49:05 +0800 Subject: [PATCH 12/13] fix gen args Former-commit-id: 144801db09ec7f183ab455d7a88c76de7639333d --- src/llmtuner/chat/hf_engine.py | 31 ++++++++++++++--------- src/llmtuner/chat/vllm_engine.py | 43 +++++++++++++------------------- 2 files changed, 36 insertions(+), 38 deletions(-) diff --git a/src/llmtuner/chat/hf_engine.py b/src/llmtuner/chat/hf_engine.py index 97160d57..5cb8bfe4 100644 --- a/src/llmtuner/chat/hf_engine.py +++ b/src/llmtuner/chat/hf_engine.py @@ -65,12 +65,13 @@ class HuggingfaceEngine(BaseEngine): prompt_length = len(prompt_ids) inputs = torch.tensor([prompt_ids], device=model.device) - do_sample = input_kwargs.pop("do_sample", None) - temperature = input_kwargs.pop("temperature", None) - top_p = input_kwargs.pop("top_p", None) - top_k = input_kwargs.pop("top_k", None) - num_return_sequences = input_kwargs.pop("num_return_sequences", None) - repetition_penalty = input_kwargs.pop("repetition_penalty", None) + do_sample = input_kwargs.pop("do_sample", generating_args["do_sample"]) + temperature = input_kwargs.pop("temperature", generating_args["temperature"]) + top_p = input_kwargs.pop("top_p", generating_args["top_p"]) + top_k = input_kwargs.pop("top_k", generating_args["top_k"]) + num_return_sequences = input_kwargs.pop("num_return_sequences", 1) + repetition_penalty = input_kwargs.pop("repetition_penalty", generating_args["repetition_penalty"]) + length_penalty = input_kwargs.pop("length_penalty", generating_args["length_penalty"]) max_length = input_kwargs.pop("max_length", None) max_new_tokens = input_kwargs.pop("max_new_tokens", None) stop = input_kwargs.pop("stop", None) @@ -78,14 +79,16 @@ class HuggingfaceEngine(BaseEngine): if stop is not None: raise ValueError("Stop parameter is not supported in Huggingface engine yet.") + generating_args = generating_args.copy() generating_args.update( dict( - do_sample=do_sample if do_sample is not None else generating_args["do_sample"], - temperature=temperature or generating_args["temperature"], - top_p=top_p or generating_args["top_p"], - top_k=top_k or generating_args["top_k"], - num_return_sequences=num_return_sequences or 1, - repetition_penalty=repetition_penalty or generating_args["repetition_penalty"], + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + top_k=top_k, + num_return_sequences=num_return_sequences, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids, pad_token_id=tokenizer.pad_token_id, ) @@ -94,6 +97,10 @@ class HuggingfaceEngine(BaseEngine): if isinstance(num_return_sequences, int) and num_return_sequences > 1: generating_args["do_sample"] = True + if not generating_args["do_sample"]: + generating_args.pop("temperature", None) + generating_args.pop("top_p", None) + if max_length: generating_args.pop("max_new_tokens", None) generating_args["max_length"] = max_length diff --git 
a/src/llmtuner/chat/vllm_engine.py b/src/llmtuner/chat/vllm_engine.py index d50e41aa..faf8c9fe 100644 --- a/src/llmtuner/chat/vllm_engine.py +++ b/src/llmtuner/chat/vllm_engine.py @@ -89,43 +89,34 @@ class VllmEngine(BaseEngine): ) prompt_length = len(prompt_ids) - temperature = input_kwargs.pop("temperature", None) - top_p = input_kwargs.pop("top_p", None) - top_k = input_kwargs.pop("top_k", None) - num_return_sequences = input_kwargs.pop("num_return_sequences", None) - repetition_penalty = input_kwargs.pop("repetition_penalty", None) + use_beam_search = self.generating_args["num_beams"] > 1 + temperature = input_kwargs.pop("temperature", self.generating_args["temperature"]) + top_p = input_kwargs.pop("top_p", self.generating_args["top_p"]) + top_k = input_kwargs.pop("top_k", self.generating_args["top_k"]) + num_return_sequences = input_kwargs.pop("num_return_sequences", 1) + repetition_penalty = input_kwargs.pop("repetition_penalty", self.generating_args["repetition_penalty"]) + length_penalty = input_kwargs.pop("length_penalty", self.generating_args["length_penalty"]) max_length = input_kwargs.pop("max_length", None) max_new_tokens = input_kwargs.pop("max_new_tokens", None) stop = input_kwargs.pop("stop", None) - generating_args = self.generating_args.copy() - generating_args.update( - dict( - temperature=temperature or generating_args["temperature"], - top_p=top_p or generating_args["top_p"], - top_k=top_k or generating_args["top_k"], - num_return_sequences=num_return_sequences or 1, - repetition_penalty=repetition_penalty or generating_args["repetition_penalty"], - ) - ) - if max_length: - generating_args["max_new_tokens"] = max_length - prompt_length + max_tokens = max_length - prompt_length if max_new_tokens: - generating_args["max_new_tokens"] = max_new_tokens + max_tokens = max_new_tokens sampling_params = SamplingParams( - n=generating_args["num_return_sequences"], - repetition_penalty=generating_args["repetition_penalty"], - temperature=generating_args["temperature"], - top_p=generating_args["top_p"], - top_k=generating_args["top_k"], - use_beam_search=generating_args["num_beams"] > 1, - length_penalty=generating_args["length_penalty"], + n=num_return_sequences, + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + use_beam_search=use_beam_search, + length_penalty=length_penalty, stop=stop, stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, - max_tokens=generating_args["max_new_tokens"], + max_tokens=max_tokens, skip_special_tokens=True, ) From 967b9c0a498c1c527ae0bbfb6350535eee6b13e2 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 15 May 2024 02:17:54 +0800 Subject: [PATCH 13/13] fix bug in vllm engine Former-commit-id: 11bf282dcc0ee257f2c28f46cc1a8edcf62421dc --- src/llmtuner/chat/vllm_engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llmtuner/chat/vllm_engine.py b/src/llmtuner/chat/vllm_engine.py index faf8c9fe..aaaad2f1 100644 --- a/src/llmtuner/chat/vllm_engine.py +++ b/src/llmtuner/chat/vllm_engine.py @@ -100,8 +100,9 @@ class VllmEngine(BaseEngine): max_new_tokens = input_kwargs.pop("max_new_tokens", None) stop = input_kwargs.pop("stop", None) + max_tokens = self.generating_args["max_new_tokens"] or self.generating_args["max_length"] if max_length: - max_tokens = max_length - prompt_length + max_tokens = max_length - prompt_length if max_length > prompt_length else 1 if max_new_tokens: max_tokens = max_new_tokens
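Taken together, the last two patches move the generation kwargs onto explicit defaults (so an explicit `0` or `False` override is no longer swallowed by an `or` chain) and clamp the vLLM token budget when the prompt is already longer than `max_length`. A standalone sketch of that resolution logic, with simplified names that are not the actual engine code:

```python
from typing import Any, Dict, Optional


def resolve_max_tokens(
    defaults: Dict[str, Any],
    prompt_length: int,
    max_length: Optional[int] = None,
    max_new_tokens: Optional[int] = None,
) -> int:
    # Fall back to the configured generating arguments first.
    max_tokens = defaults["max_new_tokens"] or defaults["max_length"]
    if max_length:
        # Clamp to 1 instead of going negative when the prompt already
        # exceeds max_length (the fix in the final patch).
        max_tokens = max_length - prompt_length if max_length > prompt_length else 1
    if max_new_tokens:
        max_tokens = max_new_tokens
    return max_tokens


# Prompt longer than max_length: returns 1 rather than a negative budget.
print(resolve_max_tokens({"max_new_tokens": 0, "max_length": 512}, prompt_length=800, max_length=600))
```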