update training resuming

Former-commit-id: 58f13e22da18babed0d2d4348474e07745da8fa5
This commit is contained in:
hiyouga 2023-08-18 01:41:17 +08:00
parent f128cfe880
commit fceca0bb6a
7 changed files with 57 additions and 29 deletions

View File

@ -12,6 +12,8 @@
## Changelog ## Changelog
[23/08/18] Now we support **resuming training**, upgrade `transformers` to `4.31.0` to enjoy this feature.
[23/08/12] Now we support **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings. [23/08/12] Now we support **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature). [23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature).
@ -158,6 +160,8 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
CUDA_VISIBLE_DEVICES=0 python src/train_web.py CUDA_VISIBLE_DEVICES=0 python src/train_web.py
``` ```
We strongly recommend using the all-in-one Web UI for newcomers since it can also generate training scripts **automatically**.
Currently the web UI only supports training on **a single GPU**. Currently the web UI only supports training on **a single GPU**.
### Pre-Training ### Pre-Training

View File

@ -12,6 +12,8 @@
## 更新日志 ## 更新日志
[23/08/18] 现在我们支持了**训练状态恢复**,请将 `transformers` 升级至 `4.31.0` 以启用此功能。
[23/08/12] 现在我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请尝试使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。 [23/08/12] 现在我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请尝试使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
[23/08/11] 现在我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。详情请参阅[此示例](#dpo-训练)(实验性功能)。 [23/08/11] 现在我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。详情请参阅[此示例](#dpo-训练)(实验性功能)。
@ -24,7 +26,7 @@
[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。使用 LLaMA-2-chat 模型时请添加 `--template llama2` 参数。 [23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。使用 LLaMA-2-chat 模型时请添加 `--template llama2` 参数。
[23/07/18] 我们开发了支持训练和测试的**一体化浏览器界面**。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。 [23/07/18] 我们开发了支持训练和测试的**浏览器一体化界面**。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path baichuan-inc/Baichuan-13B-Base``--lora_target W_pack` 参数。使用 Baichuan-13B-Chat 模型时请添加 `--template baichuan` 参数。 [23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path baichuan-inc/Baichuan-13B-Base``--lora_target W_pack` 参数。使用 Baichuan-13B-Chat 模型时请添加 `--template baichuan` 参数。
@ -152,12 +154,14 @@ pip install -r requirements.txt
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
``` ```
### 浏览器一键微调/测试 ### 浏览器一体化界面
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_web.py CUDA_VISIBLE_DEVICES=0 python src/train_web.py
``` ```
我们极力推荐新手使用浏览器一体化界面,因为它还可以**自动**生成运行所需的命令行脚本。
目前网页 UI 仅支持**单卡训练**。 目前网页 UI 仅支持**单卡训练**。
### 预训练 ### 预训练
@ -451,6 +455,8 @@ python src/export_model.py \
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) - [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [InternLM](https://github.com/InternLM/InternLM#open-source-license) - [InternLM](https://github.com/InternLM/InternLM#open-source-license)
- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE) - [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
- [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
- [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE)
## 引用 ## 引用

View File

@ -121,6 +121,9 @@ def load_model_and_tokenizer(
# Quantization configurations (using bitsandbytes library). # Quantization configurations (using bitsandbytes library).
is_mergeable = True is_mergeable = True
if model_args.quantization_bit is not None: if model_args.quantization_bit is not None:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
if model_args.quantization_bit == 8: if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["load_in_8bit"] = True config_kwargs["load_in_8bit"] = True
@ -144,7 +147,7 @@ def load_model_and_tokenizer(
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
model_to_load, model_to_load,
config=config, config=config,
torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16, torch_dtype=model_args.compute_dtype,
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()), low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
**config_kwargs **config_kwargs
) )

View File

@ -5,6 +5,7 @@ import datasets
import transformers import transformers
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser, Seq2SeqTrainingArguments from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from llmtuner.extras.logging import get_logger from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ( from llmtuner.hparams import (
@ -97,30 +98,33 @@ def get_train_args(
if general_args.stage != "sft" and training_args.predict_with_generate: if general_args.stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.") raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate: if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.") raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora": if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
raise ValueError("RM and PPO training can only be performed with the LoRA method.") raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
if general_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
if general_args.stage in ["ppo", "dpo"] and not training_args.do_train: if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
raise ValueError("PPO and DPO stage can only be performed at training.") raise ValueError("PPO and DPO stages can only be performed at training.")
if general_args.stage == "ppo" and model_args.reward_model is None: if general_args.stage == "ppo" and model_args.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.") raise ValueError("Reward model is necessary for PPO training.")
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
if general_args.stage == "ppo" and data_args.streaming: if general_args.stage == "ppo" and data_args.streaming:
raise ValueError("Streaming mode does not suppport PPO training currently.") raise ValueError("Streaming mode does not suppport PPO training currently.")
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming: if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming:
raise ValueError("Streaming mode should have an integer val size.") raise ValueError("Streaming mode should have an integer val size.")
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.") raise ValueError("Quantization is only compatible with the LoRA method.")
@ -134,9 +138,15 @@ def get_train_args(
if model_args.quantization_bit is not None and (not training_args.do_train): if model_args.quantization_bit is not None and (not training_args.do_train):
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
if training_args.do_train and (not training_args.fp16): if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning("We recommend enable fp16 mixed precision training.") logger.warning("We recommend enable mixed precision training.")
# postprocess data_args
if data_args.max_samples is not None and data_args.streaming:
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
data_args.max_samples = None
# postprocess training_args
if ( if (
training_args.local_rank != -1 training_args.local_rank != -1
and training_args.ddp_find_unused_parameters is None and training_args.ddp_find_unused_parameters is None
@ -145,12 +155,26 @@ def get_train_args(
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.") logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
training_args.ddp_find_unused_parameters = False training_args.ddp_find_unused_parameters = False
if data_args.max_samples is not None and data_args.streaming: if training_args.optim == "adamw_hf":
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.") training_args.optim = "adamw_torch" # suppress warning
data_args.max_samples = None
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning if (
training_args.resume_from_checkpoint is None
and training_args.do_train
and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
if last_checkpoint is not None:
training_args.resume_from_checkpoint = last_checkpoint
logger.info(
"Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
)
# postprocess model_args
if training_args.bf16: if training_args.bf16:
if not torch.cuda.is_bf16_supported(): if not torch.cuda.is_bf16_supported():
raise ValueError("Current device does not support bf16 training.") raise ValueError("Current device does not support bf16 training.")

View File

@ -50,7 +50,7 @@ def run_dpo(
# Training # Training
if training_args.do_train: if training_args.do_train:
train_result = trainer.train() train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.log_metrics("train", train_result.metrics) trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics)
trainer.save_state() trainer.save_state()

View File

@ -39,7 +39,7 @@ def run_pt(
# Training # Training
if training_args.do_train: if training_args.do_train:
train_result = trainer.train() train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.log_metrics("train", train_result.metrics) trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics)
trainer.save_state() trainer.save_state()

View File

@ -1,5 +1,4 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py # Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
import os
from typing import TYPE_CHECKING, Optional, List from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq from transformers import DataCollatorForSeq2Seq
@ -11,14 +10,11 @@ from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.sft.metric import ComputeMetrics from llmtuner.tuner.sft.metric import ComputeMetrics
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
from transformers.trainer_utils import get_last_checkpoint
from llmtuner.extras.logging import reset_logging, get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
logger = get_logger(__name__)
def run_sft( def run_sft(
model_args: "ModelArguments", model_args: "ModelArguments",
@ -62,12 +58,7 @@ def run_sft(
# Training # Training
if training_args.do_train: if training_args.do_train:
checkpoint = None train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
trainer.log_metrics("train", train_result.metrics) trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics)
trainer.save_state() trainer.save_state()