mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
fix PPO trainer #551 , update readme
Former-commit-id: 90205244186df558cd6b0000728d638348db3a10
This commit is contained in:
parent
e93e9641f5
commit
03edfd07e7
29
README.md
29
README.md
@ -171,11 +171,12 @@ Currently the web UI only supports training on **a single GPU**.
|
|||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||||
--stage pt \
|
--stage pt \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--do_train \
|
--do_train \
|
||||||
--dataset wiki_demo \
|
--dataset wiki_demo \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
--output_dir path_to_pt_checkpoint \
|
--output_dir path_to_pt_checkpoint \
|
||||||
--overwrite_cache \
|
--overwrite_cache \
|
||||||
--per_device_train_batch_size 4 \
|
--per_device_train_batch_size 4 \
|
||||||
@ -194,11 +195,12 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||||
--stage sft \
|
--stage sft \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--do_train \
|
--do_train \
|
||||||
--dataset alpaca_gpt4_en \
|
--dataset alpaca_gpt4_en \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
--output_dir path_to_sft_checkpoint \
|
--output_dir path_to_sft_checkpoint \
|
||||||
--overwrite_cache \
|
--overwrite_cache \
|
||||||
--per_device_train_batch_size 4 \
|
--per_device_train_batch_size 4 \
|
||||||
@ -217,11 +219,12 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||||
--stage rm \
|
--stage rm \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--do_train \
|
--do_train \
|
||||||
--dataset comparison_gpt4_en \
|
--dataset comparison_gpt4_en \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
--resume_lora_training False \
|
--resume_lora_training False \
|
||||||
--checkpoint_dir path_to_sft_checkpoint \
|
--checkpoint_dir path_to_sft_checkpoint \
|
||||||
--output_dir path_to_rm_checkpoint \
|
--output_dir path_to_rm_checkpoint \
|
||||||
@ -230,7 +233,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||||||
--lr_scheduler_type cosine \
|
--lr_scheduler_type cosine \
|
||||||
--logging_steps 10 \
|
--logging_steps 10 \
|
||||||
--save_steps 1000 \
|
--save_steps 1000 \
|
||||||
--learning_rate 1e-5 \
|
--learning_rate 1e-6 \
|
||||||
--num_train_epochs 1.0 \
|
--num_train_epochs 1.0 \
|
||||||
--plot_loss \
|
--plot_loss \
|
||||||
--fp16
|
--fp16
|
||||||
@ -241,11 +244,12 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||||
--stage ppo \
|
--stage ppo \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--do_train \
|
--do_train \
|
||||||
--dataset alpaca_gpt4_en \
|
--dataset alpaca_gpt4_en \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
--resume_lora_training False \
|
--resume_lora_training False \
|
||||||
--checkpoint_dir path_to_sft_checkpoint \
|
--checkpoint_dir path_to_sft_checkpoint \
|
||||||
--reward_model path_to_rm_checkpoint \
|
--reward_model path_to_rm_checkpoint \
|
||||||
@ -266,11 +270,12 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||||
--stage dpo \
|
--stage dpo \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--do_train \
|
--do_train \
|
||||||
--dataset comparison_gpt4_en \
|
--dataset comparison_gpt4_en \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
--resume_lora_training False \
|
--resume_lora_training False \
|
||||||
--checkpoint_dir path_to_sft_checkpoint \
|
--checkpoint_dir path_to_sft_checkpoint \
|
||||||
--output_dir path_to_dpo_checkpoint \
|
--output_dir path_to_dpo_checkpoint \
|
||||||
@ -364,7 +369,7 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/export_model.py \
|
python src/export_model.py \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
--checkpoint_dir path_to_checkpoint \
|
--checkpoint_dir path_to_checkpoint \
|
||||||
@ -375,7 +380,7 @@ python src/export_model.py \
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/api_demo.py \
|
python src/api_demo.py \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
--checkpoint_dir path_to_checkpoint
|
--checkpoint_dir path_to_checkpoint
|
||||||
@ -387,7 +392,7 @@ Visit `http://localhost:8000/docs` for API documentation.
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/cli_demo.py \
|
python src/cli_demo.py \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
--checkpoint_dir path_to_checkpoint
|
--checkpoint_dir path_to_checkpoint
|
||||||
@ -397,7 +402,7 @@ python src/cli_demo.py \
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/web_demo.py \
|
python src/web_demo.py \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
--checkpoint_dir path_to_checkpoint
|
--checkpoint_dir path_to_checkpoint
|
||||||
@ -408,7 +413,7 @@ python src/web_demo.py \
|
|||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||||
--stage sft \
|
--stage sft \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--do_eval \
|
--do_eval \
|
||||||
--dataset alpaca_gpt4_en \
|
--dataset alpaca_gpt4_en \
|
||||||
--template default \
|
--template default \
|
||||||
@ -427,7 +432,7 @@ We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128
|
|||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||||
--stage sft \
|
--stage sft \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--do_predict \
|
--do_predict \
|
||||||
--dataset alpaca_gpt4_en \
|
--dataset alpaca_gpt4_en \
|
||||||
--template default \
|
--template default \
|
||||||
|
29
README_zh.md
29
README_zh.md
@ -171,11 +171,12 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py
|
|||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||||
--stage pt \
|
--stage pt \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--do_train \
|
--do_train \
|
||||||
--dataset wiki_demo \
|
--dataset wiki_demo \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
--output_dir path_to_pt_checkpoint \
|
--output_dir path_to_pt_checkpoint \
|
||||||
--overwrite_cache \
|
--overwrite_cache \
|
||||||
--per_device_train_batch_size 4 \
|
--per_device_train_batch_size 4 \
|
||||||
@ -194,11 +195,12 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||||
--stage sft \
|
--stage sft \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--do_train \
|
--do_train \
|
||||||
--dataset alpaca_gpt4_zh \
|
--dataset alpaca_gpt4_zh \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
--output_dir path_to_sft_checkpoint \
|
--output_dir path_to_sft_checkpoint \
|
||||||
--overwrite_cache \
|
--overwrite_cache \
|
||||||
--per_device_train_batch_size 4 \
|
--per_device_train_batch_size 4 \
|
||||||
@ -217,11 +219,12 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||||
--stage rm \
|
--stage rm \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--do_train \
|
--do_train \
|
||||||
--dataset comparison_gpt4_zh \
|
--dataset comparison_gpt4_zh \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
--resume_lora_training False \
|
--resume_lora_training False \
|
||||||
--checkpoint_dir path_to_sft_checkpoint \
|
--checkpoint_dir path_to_sft_checkpoint \
|
||||||
--output_dir path_to_rm_checkpoint \
|
--output_dir path_to_rm_checkpoint \
|
||||||
@ -230,7 +233,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||||||
--lr_scheduler_type cosine \
|
--lr_scheduler_type cosine \
|
||||||
--logging_steps 10 \
|
--logging_steps 10 \
|
||||||
--save_steps 1000 \
|
--save_steps 1000 \
|
||||||
--learning_rate 1e-5 \
|
--learning_rate 1e-6 \
|
||||||
--num_train_epochs 1.0 \
|
--num_train_epochs 1.0 \
|
||||||
--plot_loss \
|
--plot_loss \
|
||||||
--fp16
|
--fp16
|
||||||
@ -241,11 +244,12 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||||
--stage ppo \
|
--stage ppo \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--do_train \
|
--do_train \
|
||||||
--dataset alpaca_gpt4_zh \
|
--dataset alpaca_gpt4_zh \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
--resume_lora_training False \
|
--resume_lora_training False \
|
||||||
--checkpoint_dir path_to_sft_checkpoint \
|
--checkpoint_dir path_to_sft_checkpoint \
|
||||||
--reward_model path_to_rm_checkpoint \
|
--reward_model path_to_rm_checkpoint \
|
||||||
@ -265,11 +269,12 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||||
--stage dpo \
|
--stage dpo \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--do_train \
|
--do_train \
|
||||||
--dataset comparison_gpt4_zh \
|
--dataset comparison_gpt4_zh \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
--resume_lora_training False \
|
--resume_lora_training False \
|
||||||
--checkpoint_dir path_to_sft_checkpoint \
|
--checkpoint_dir path_to_sft_checkpoint \
|
||||||
--output_dir path_to_dpo_checkpoint \
|
--output_dir path_to_dpo_checkpoint \
|
||||||
@ -363,7 +368,7 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/export_model.py \
|
python src/export_model.py \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
--checkpoint_dir path_to_checkpoint \
|
--checkpoint_dir path_to_checkpoint \
|
||||||
@ -374,7 +379,7 @@ python src/export_model.py \
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/api_demo.py \
|
python src/api_demo.py \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
--checkpoint_dir path_to_checkpoint
|
--checkpoint_dir path_to_checkpoint
|
||||||
@ -386,7 +391,7 @@ python src/api_demo.py \
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/cli_demo.py \
|
python src/cli_demo.py \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
--checkpoint_dir path_to_checkpoint
|
--checkpoint_dir path_to_checkpoint
|
||||||
@ -396,7 +401,7 @@ python src/cli_demo.py \
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/web_demo.py \
|
python src/web_demo.py \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
--checkpoint_dir path_to_checkpoint
|
--checkpoint_dir path_to_checkpoint
|
||||||
@ -407,7 +412,7 @@ python src/web_demo.py \
|
|||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||||
--stage sft \
|
--stage sft \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--do_eval \
|
--do_eval \
|
||||||
--dataset alpaca_gpt4_zh \
|
--dataset alpaca_gpt4_zh \
|
||||||
--template default \
|
--template default \
|
||||||
@ -426,7 +431,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||||
--stage sft \
|
--stage sft \
|
||||||
--model_name_or_path path_to_your_model \
|
--model_name_or_path path_to_llama_model \
|
||||||
--do_predict \
|
--do_predict \
|
||||||
--dataset alpaca_gpt4_zh \
|
--dataset alpaca_gpt4_zh \
|
||||||
--template default \
|
--template default \
|
||||||
|
@ -8,7 +8,7 @@ class FinetuningArguments:
|
|||||||
r"""
|
r"""
|
||||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||||
"""
|
"""
|
||||||
finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
|
finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
|
||||||
default="lora",
|
default="lora",
|
||||||
metadata={"help": "Which fine-tuning method to use."}
|
metadata={"help": "Which fine-tuning method to use."}
|
||||||
)
|
)
|
||||||
@ -49,7 +49,7 @@ class FinetuningArguments:
|
|||||||
metadata={"help": "Dropout rate for the LoRA fine-tuning."}
|
metadata={"help": "Dropout rate for the LoRA fine-tuning."}
|
||||||
)
|
)
|
||||||
lora_target: Optional[str] = field(
|
lora_target: Optional[str] = field(
|
||||||
default="q_proj,v_proj",
|
default=None,
|
||||||
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
|
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
|
||||||
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
||||||
BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
|
BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
|
||||||
@ -77,7 +77,7 @@ class FinetuningArguments:
|
|||||||
|
|
||||||
self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
|
self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
|
||||||
|
|
||||||
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
|
assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
|
||||||
|
|
||||||
def save_to_json(self, json_path: str):
|
def save_to_json(self, json_path: str):
|
||||||
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||||
|
@ -125,6 +125,9 @@ def get_train_args(
|
|||||||
if training_args.do_train and training_args.predict_with_generate:
|
if training_args.do_train and training_args.predict_with_generate:
|
||||||
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
||||||
|
|
||||||
|
if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
|
||||||
|
raise ValueError("Please specify `lora_target` in LoRA training.")
|
||||||
|
|
||||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
|
|||||||
from llmtuner.extras.logging import get_logger
|
from llmtuner.extras.logging import get_logger
|
||||||
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
|
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
|
||||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||||
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
|
from llmtuner.tuner.ppo.utils import replace_model
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import Seq2SeqTrainingArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
@ -152,10 +152,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||||||
if length_sampler is not None:
|
if length_sampler is not None:
|
||||||
generation_kwargs["max_new_tokens"] = length_sampler()
|
generation_kwargs["max_new_tokens"] = length_sampler()
|
||||||
|
|
||||||
self.model, layer_norm_params = cast_layernorm_dtype(self.model)
|
|
||||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||||
response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
|
response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
|
||||||
self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
|
|
||||||
|
|
||||||
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
|
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
|
||||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
|
# Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
|
||||||
|
@ -1,7 +1,4 @@
|
|||||||
import torch
|
from typing import TYPE_CHECKING, Literal
|
||||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
|
|
||||||
|
|
||||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trl import AutoModelForCausalLMWithValueHead
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
@ -18,22 +15,3 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
|
|||||||
"summary.weight": getattr(model, "{}_head_weight".format(target)),
|
"summary.weight": getattr(model, "{}_head_weight".format(target)),
|
||||||
"summary.bias": getattr(model, "{}_head_bias".format(target))
|
"summary.bias": getattr(model, "{}_head_bias".format(target))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def cast_layernorm_dtype(
|
|
||||||
model: "AutoModelForCausalLMWithValueHead",
|
|
||||||
layer_norm_names: List[str] = LAYERNORM_NAMES,
|
|
||||||
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
|
|
||||||
) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]:
|
|
||||||
|
|
||||||
layer_norm_state_dict = {}
|
|
||||||
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
|
|
||||||
if layer_norm_params is not None:
|
|
||||||
param.data = layer_norm_params[name] # restore float32 weights
|
|
||||||
else:
|
|
||||||
layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
|
|
||||||
param.data = param.data.to(torch.float16)
|
|
||||||
|
|
||||||
return model, layer_norm_state_dict
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user