update ppo trainer

Former-commit-id: 5021062493ed63ad1f6133cfb543e4e7f528d2cc
hiyouga 2023-11-20 21:39:15 +08:00
parent d72f123851
commit f06c4c8f7a
7 changed files with 68 additions and 41 deletions

View File

@@ -323,6 +323,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16
```
> [!WARNING]
> Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 training.
#### DPO Training
```bash
@@ -418,7 +421,7 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
</details>
### Export model
### Merge LoRA weights and export model
```bash
python src/export_model.py \
@@ -439,7 +442,7 @@ python src/api_demo.py \
--checkpoint_dir path_to_checkpoint
```
> [!NOTE]
> [!TIP]
> Visit `http://localhost:8000/docs` for API documentation.
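For a quick smoke test of the running server, a request like the following works. This is a minimal sketch assuming the demo exposes an OpenAI-style `/v1/chat/completions` route; confirm the exact path and payload schema on the `/docs` page:
```python
import requests

# Assumed OpenAI-style endpoint; verify the schema at http://localhost:8000/docs
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={"model": "default", "messages": [{"role": "user", "content": "Hello!"}]},
)
print(resp.json())
```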
### CLI Demo
@@ -491,10 +494,14 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--output_dir path_to_predict_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate
--predict_with_generate \
--fp16
```
> [!NOTE]
> [!WARNING]
> Use `--per_device_eval_batch_size=1` for LLaMA-2 models in fp16 prediction.
> [!TIP]
> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` for 4/8-bit prediction.
## Projects using LLaMA Factory
@@ -504,7 +511,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in the Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for the Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
> [!NOTE]
> [!TIP]
> If you have a project that should be incorporated, please contact us via email or create a pull request.
## License

View File

@@ -319,9 +319,13 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss
--plot_loss \
--fp16
```
> [!WARNING]
> Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 training.
#### DPO Training
```bash
@@ -417,7 +421,7 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
</details>
### Export the full fine-tuned model
### Merge LoRA weights and export the full model
```bash
python src/export_model.py \
@@ -438,7 +442,7 @@ python src/api_demo.py \
--checkpoint_dir path_to_checkpoint
```
> [!NOTE]
> [!TIP]
> Visit `http://localhost:8000/docs` for the API documentation.
### CLI Demo
@@ -490,10 +494,14 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--output_dir path_to_predict_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate
--predict_with_generate \
--fp16
```
> [!NOTE]
> [!WARNING]
> Use `--per_device_eval_batch_size=1` for LLaMA-2 models in fp16 prediction.
> [!TIP]
> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` when predicting with quantized models.
## Projects using LLaMA Factory
@@ -503,7 +511,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: Sunsimiao, a Chinese medical large language model fine-tuned from Baichuan-7B and ChatGLM-6B on Chinese medical data.
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: CareGPT, a medical large language model project fine-tuned from LLaMA2-7B and Baichuan-13B on Chinese medical data.
> [!NOTE]
> [!TIP]
> If you would like your project added to the list above, please reach out via email or create a PR.
## License

View File

@@ -16,7 +16,10 @@ try:
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
_is_fp16_available = torch.cuda.is_available()
_is_bf16_available = torch.cuda.is_bf16_supported()
try:
_is_bf16_available = torch.cuda.is_bf16_supported()
except Exception:  # torch.cuda.is_bf16_supported() can raise without a CUDA device
_is_bf16_available = False
if TYPE_CHECKING:
from transformers import HfArgumentParser
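For context, the complete probing logic this hunk produces looks roughly like the sketch below (a paraphrase of the surrounding file, not a verbatim copy). The inner `try`/`except` is the fix: `torch.cuda.is_bf16_supported()` can raise when no CUDA device is present, so the fallback path no longer crashes on CPU-only machines:
```python
import torch

try:
    # Preferred: let transformers probe the hardware.
    from transformers.utils import is_torch_bf16_cpu_available, is_torch_bf16_gpu_available
    _is_fp16_available = torch.cuda.is_available()
    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
    # Fallback for older transformers versions.
    _is_fp16_available = torch.cuda.is_available()
    try:
        _is_bf16_available = torch.cuda.is_bf16_supported()
    except Exception:  # raises on CPU-only builds
        _is_bf16_available = False
```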

View File

@@ -8,10 +8,6 @@ class FreezeArguments:
r"""
Arguments pertaining to the freeze (partial-parameter) training.
"""
num_layer_trainable: Optional[int] = field(
default=3,
metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
)
name_module_trainable: Optional[str] = field(
default="mlp",
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
@@ -22,6 +18,10 @@ class FreezeArguments:
Phi-1.5 choices: [\"mlp\", \"mixer\"], \
Others choices: the same as LLaMA."}
)
num_layer_trainable: Optional[int] = field(
default=3,
metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}
)
@dataclass
@@ -29,9 +29,9 @@ class LoraArguments:
r"""
Arguments pertaining to the LoRA training.
"""
lora_rank: Optional[int] = field(
default=8,
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
additional_target: Optional[str] = field(
default=None,
metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
)
lora_alpha: Optional[float] = field(
default=None,
@@ -41,6 +41,10 @@ class LoraArguments:
default=0.1,
metadata={"help": "Dropout rate for the LoRA fine-tuning."}
)
lora_rank: Optional[int] = field(
default=8,
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
)
lora_target: Optional[str] = field(
default=None,
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
@@ -51,10 +55,6 @@ class LoraArguments:
Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
Others choices: the same as LLaMA."}
)
additional_target: Optional[str] = field(
default=None,
metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
)
resume_lora_training: Optional[bool] = field(
default=True,
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
@@ -70,13 +70,17 @@ class RLHFArguments:
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."}
)
ppo_logger: Optional[str] = field(
default=None,
metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
ppo_buffer_size: Optional[int] = field(
default=1,
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}
)
ppo_epochs: Optional[int] = field(
default=4,
metadata={"help": "Number of optimisation epochs per batch of samples"},
metadata={"help": "The number of epochs to perform in a PPO optimization step."}
)
ppo_logger: Optional[str] = field(
default=None,
metadata={"help": "Log with either \"wandb\" or \"tensorboard\" in PPO training."}
)
ppo_score_norm: Optional[bool] = field(
default=False,
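Since the hunks above only alphabetize fields and reword help strings, parsing behavior is unchanged. A self-contained sketch of how such dataclass fields become CLI flags through `HfArgumentParser` (the miniature class below is illustrative, not the repo's):
```python
from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser

@dataclass
class MiniArgs:  # stand-in for FreezeArguments/LoraArguments/RLHFArguments
    lora_rank: Optional[int] = field(default=8, metadata={"help": "The intrinsic dimension for LoRA fine-tuning."})
    ppo_buffer_size: Optional[int] = field(default=1, metadata={"help": "Mini-batches per experience buffer."})

parser = HfArgumentParser(MiniArgs)
(args,) = parser.parse_args_into_dataclasses(args=["--lora_rank", "16", "--ppo_buffer_size", "4"])
print(args.lora_rank, args.ppo_buffer_size)  # 16 4
```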

View File

@@ -202,7 +202,7 @@ def load_model_and_tokenizer(
if stage in ["rm", "ppo"]:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
setattr(model, "_keys_to_ignore_on_save", [name for name, _ in model.named_parameters() if "pretrained_model" in name])
setattr(model, "tie_weights", MethodType(lambda _: None, model))
setattr(model, "tie_weights", MethodType(lambda _: None, model)) # use empty method
vhead_path = (
model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
)
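A note on the `tie_weights` stub: `from_pretrained` and some save paths call `tie_weights()` unconditionally, which the value-head wrapper does not handle well, so the commit rebinds it to a no-op on the instance. The idiom in isolation:
```python
from types import MethodType

class Wrapper:
    def tie_weights(self):
        raise RuntimeError("unsupported on the wrapper")

obj = Wrapper()
obj.tie_weights = MethodType(lambda _: None, obj)  # instance attribute shadows the class method
obj.tie_weights()  # now a harmless no-op
```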

View File

@@ -82,13 +82,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
raise ValueError("`resume_from_checkpoint` will be supported in a future version.")
total_train_batch_size = (
self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
self.args.per_device_train_batch_size
* self.args.gradient_accumulation_steps
* self.finetuning_args.ppo_buffer_size
* self.args.world_size
)
if self.args.max_steps > 0:
num_examples = total_train_batch_size * self.args.max_steps
num_train_epochs = sys.maxsize
max_steps = self.args.max_steps
steps_in_epoch = self.args.max_steps * self.args.gradient_accumulation_steps
steps_in_epoch = self.args.max_steps
else:
len_dataloader = len(self.dataloader)
num_examples = len(self.dataset)
@@ -103,13 +106,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.is_world_process_zero():
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
logger.info(" Num examples = {}".format(num_examples))
logger.info(" Num Epochs = {}".format(num_train_epochs))
logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
logger.info(" Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
total_train_batch_size
))
logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
logger.info(" Total training steps = {}".format(max_steps))
logger.info(" Number of trainable parameters = {}".format(count_parameters(self.model)[0]))
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)
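The step accounting in this hunk is easiest to check numerically. With the buffer factor included, and `steps_in_epoch` no longer multiplied by `gradient_accumulation_steps` (each PPO step already consumes one full buffered batch), the math works out as in this sketch with illustrative values:
```python
import sys

# Illustrative settings, not repo defaults.
per_device_bs, grad_accum, buffer_size, world_size = 1, 2, 4, 2
total_train_batch_size = per_device_bs * grad_accum * buffer_size * world_size  # 16

max_steps = 100
if max_steps > 0:
    num_examples = total_train_batch_size * max_steps  # 1600 prompts consumed overall
    num_train_epochs = sys.maxsize                     # run until max_steps, ignore epochs
    steps_in_epoch = max_steps                         # fixed: no grad-accum multiplier
print(num_examples, steps_in_epoch)  # 1600 100
```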

View File

@@ -39,11 +39,12 @@ def run_ppo(
reward_model = create_reward_model(model, model_args, finetuning_args)
# Create ppo config
backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate,
mini_batch_size=training_args.per_device_train_batch_size,
batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
batch_size=backward_batch_size * finetuning_args.ppo_buffer_size,
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
ppo_epochs=finetuning_args.ppo_epochs,
max_grad_norm=training_args.max_grad_norm,
@@ -62,9 +63,7 @@ def run_ppo(
if training_args.max_steps > 0:
num_training_steps = training_args.max_steps
else:
total_train_batch_size = (
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
)
total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
lr_scheduler = get_scheduler(
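The `PPOConfig` wiring keeps TRL's invariant intact: the rollout `batch_size` must split evenly into backward passes of `mini_batch_size * gradient_accumulation_steps` (TRL's `backward_batch_size`), and `ppo_buffer_size` is exactly that multiplier. A quick check with assumed values:
```python
# Illustrative values mirroring the PPOConfig call above.
per_device_train_batch_size = 2
gradient_accumulation_steps = 4
ppo_buffer_size = 8

backward_batch_size = per_device_train_batch_size * gradient_accumulation_steps  # 8
batch_size = backward_batch_size * ppo_buffer_size                               # 64

# TRL requires the rollout batch to divide evenly into backward passes.
assert batch_size % backward_batch_size == 0
print(backward_batch_size, batch_size)  # 8 64
```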