diff --git a/README.md b/README.md
index 78175060..94238a7b 100644
--- a/README.md
+++ b/README.md
@@ -457,7 +457,7 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
     "loss_scale_window": 1000,
     "hysteresis": 2,
     "min_loss_scale": 1
-  },
+  },
   "zero_optimization": {
     "stage": 2,
     "allgather_partitions": true,
diff --git a/README_zh.md b/README_zh.md
index ef07d5b7..76205359 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -457,7 +457,7 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
     "loss_scale_window": 1000,
     "hysteresis": 2,
     "min_loss_scale": 1
-  },
+  },
   "zero_optimization": {
     "stage": 2,
     "allgather_partitions": true,
diff --git a/src/llmtuner/eval/evaluator.py b/src/llmtuner/eval/evaluator.py
index 1fbd40ee..f0d28afb 100644
--- a/src/llmtuner/eval/evaluator.py
+++ b/src/llmtuner/eval/evaluator.py
@@ -3,7 +3,6 @@
 import os
 import json
 import torch
-import inspect
 import tiktoken
 import numpy as np
 from tqdm import tqdm, trange
@@ -46,16 +45,11 @@ class Evaluator:
         return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
 
     def eval(self) -> None:
-        if "token" in inspect.signature(cached_file).parameters:
-            kwargs = {"token": self.model_args.hf_hub_token}
-        elif "use_auth_token" in inspect.signature(cached_file).parameters: # for transformers==4.31.0
-            kwargs = {"use_auth_token": self.model_args.hf_hub_token}
-
         mapping = cached_file(
             path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task),
             filename="mapping.json",
             cache_dir=self.model_args.cache_dir,
-            **kwargs
+            token=self.model_args.hf_hub_token
         )
 
         with open(mapping, "r", encoding="utf-8") as f:
diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py
index 44afacf0..2139f5db 100644
--- a/src/llmtuner/extras/callbacks.py
+++ b/src/llmtuner/extras/callbacks.py
@@ -1,17 +1,19 @@
 import os
 import json
 import time
-from typing import TYPE_CHECKING
+import torch
+from typing import TYPE_CHECKING, Dict
 from datetime import timedelta
 from transformers import PreTrainedModel, TrainerCallback
-from transformers.modeling_utils import custom_object_save, unwrap_model
+from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
 from peft import PeftModel
 
-from llmtuner.extras.constants import LOG_FILE_NAME
+from llmtuner.extras.constants import LOG_FILE_NAME, V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
 from llmtuner.extras.logging import get_logger
 
+
 if TYPE_CHECKING:
     from transformers import TrainingArguments, TrainerState, TrainerControl
     from trl import AutoModelForCausalLMWithValueHead
@@ -20,31 +22,66 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def _save_model_with_valuehead(
+def _fix_valuehead_checkpoint(
     model: "AutoModelForCausalLMWithValueHead",
     output_dir: str,
     safe_serialization: bool
 ) -> None:
-    if isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
-        model.pretrained_model.config.save_pretrained(output_dir)
-        if model.pretrained_model.can_generate():
-            model.pretrained_model.generation_config.save_pretrained(output_dir)
+    r"""
+    The model is already unwrapped.
 
-    if getattr(model, "is_peft_model", False):
-        model.pretrained_model.save_pretrained(output_dir, safe_serialization=safe_serialization)
-    elif getattr(model.pretrained_model, "_auto_class", None): # must not a peft model
-        custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
+    There are three cases:
+    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
+    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
+    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
+
+    We assume `stage3_gather_16bit_weights_on_model_save=true`.
+    """
+    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
+        return
+
+    if safe_serialization:
+        from safetensors import safe_open
+        from safetensors.torch import save_file
+        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
+        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
+            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
+    else:
+        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
+        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+
+    decoder_state_dict = {}
+    v_head_state_dict = {}
+    for name, param in state_dict.items():
+        if name.startswith("v_head."):
+            v_head_state_dict[name] = param
+        else:
+            decoder_state_dict[name.replace("pretrained_model.", "")] = param
+
+    os.remove(path_to_checkpoint)
+    model.pretrained_model.save_pretrained(
+        output_dir,
+        state_dict=decoder_state_dict or None,
+        safe_serialization=safe_serialization
+    )
+
+    if safe_serialization:
+        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
+    else:
+        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
+
+    logger.info("Value head model saved at: {}".format(output_dir))
 
 
-class SavePeftModelCallback(TrainerCallback):
+class FixValueHeadModelCallback(TrainerCallback):
 
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called after a checkpoint save.
         """
         if args.should_save:
-            _save_model_with_valuehead(
-                model=unwrap_model(kwargs.pop("model")),
+            _fix_valuehead_checkpoint(
+                model=kwargs.pop("model"),
                 output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
                 safe_serialization=args.save_safetensors
             )
@@ -54,10 +91,8 @@ class SavePeftModelCallback(TrainerCallback):
         Event called at the end of training.
         """
         if args.should_save:
-            _save_model_with_valuehead(
-                model=unwrap_model(kwargs.pop("model")),
-                output_dir=args.output_dir,
-                safe_serialization=args.save_safetensors
+            _fix_valuehead_checkpoint(
+                model=kwargs.pop("model"), output_dir=args.output_dir, safe_serialization=args.save_safetensors
             )
 
 
diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py
index 1dbd6b9d..a81db3a7 100644
--- a/src/llmtuner/extras/constants.py
+++ b/src/llmtuner/extras/constants.py
@@ -40,6 +40,10 @@ TRAINING_STAGES = {
     "Pre-Training": "pt"
 }
 
+V_HEAD_WEIGHTS_NAME = "v_head.bin"
+
+V_HEAD_SAFE_WEIGHTS_NAME = "v_head.safetensors"
+
 class DownloadSource(str, Enum):
     DEFAULT = "hf"
     MODELSCOPE = "ms"
diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py
index 824a1e79..14bd4c59 100644
--- a/src/llmtuner/model/utils.py
+++ b/src/llmtuner/model/utils.py
@@ -3,8 +3,8 @@ import inspect
 from typing import TYPE_CHECKING, Any, Dict, List
 from transformers import PreTrainedModel
 from transformers.utils import cached_file
-from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 
+from llmtuner.extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import get_current_device
 
@@ -103,22 +103,20 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
 
     try:
         from safetensors import safe_open
-        vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
+        vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
         with safe_open(vhead_file, framework="pt", device="cpu") as f:
-            return {
-                "v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
-                "v_head.summary.bias": f.get_tensor("v_head.summary.bias")
-            }
+            return {key: f.get_tensor(key) for key in f.keys()}
     except Exception as err:
-        logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))
+        logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
 
     try:
-        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
+        vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
         return torch.load(vhead_file, map_location="cpu")
     except Exception as err:
-        logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
+        logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
 
-    logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
+    logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
+    logger.info("Ignore these messages if you are not resuming the training of a value head model.")
     return None
 
 
diff --git a/src/llmtuner/train/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py
index ec9e3fdf..31cab7c0 100644
--- a/src/llmtuner/train/ppo/trainer.py
+++ b/src/llmtuner/train/ppo/trainer.py
@@ -8,11 +8,12 @@
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
 from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_pt_utils import remove_dummy_checkpoint
 from trl import PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
 
-from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
+from llmtuner.extras.callbacks import LogCallback, FixValueHeadModelCallback
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
 from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
@@ -60,7 +61,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.accelerator.state, "deepspeed_plugin"
         )
         self.log_callback, self.save_callback = callbacks[0], callbacks[1]
-        assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
+        assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback)
 
         if self.args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")
@@ -369,9 +370,5 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                     " use zero_to_fp32.py to recover weights"
                 )
                 self._save(output_dir, state_dict={})
-                for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]: # remove dummy checkpoint
-                    file = os.path.join(output_dir, filename)
-                    if os.path.isfile(file):
-                        os.remove(file)
-
-                self.model.save_checkpoint(output_dir) # wrapped model
+                remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+                self.model.save_checkpoint(output_dir)
diff --git a/src/llmtuner/train/ppo/workflow.py b/src/llmtuner/train/ppo/workflow.py
index 933f69db..1cbbd8da 100644
--- a/src/llmtuner/train/ppo/workflow.py
+++ b/src/llmtuner/train/ppo/workflow.py
@@ -8,7 +8,7 @@ from transformers import DataCollatorWithPadding
 from transformers.optimization import get_scheduler
 
 from llmtuner.data import get_dataset, preprocess_dataset
-from llmtuner.extras.callbacks import SavePeftModelCallback
+from llmtuner.extras.callbacks import FixValueHeadModelCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.model import load_model_and_tokenizer
 from llmtuner.train.utils import create_ref_model, create_reward_model
@@ -79,7 +79,7 @@ def run_ppo(
         training_args=training_args,
         finetuning_args=finetuning_args,
         generating_args=generating_args,
-        callbacks=callbacks + [SavePeftModelCallback()],
+        callbacks=callbacks + [FixValueHeadModelCallback()],
         reward_model=reward_model,
         config=ppo_config,
         model=model,
diff --git a/src/llmtuner/train/rm/workflow.py b/src/llmtuner/train/rm/workflow.py
index 944024ab..f720aaf2 100644
--- a/src/llmtuner/train/rm/workflow.py
+++ b/src/llmtuner/train/rm/workflow.py
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Optional, List
 from transformers import Seq2SeqTrainingArguments
 
 from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.callbacks import SavePeftModelCallback
+from llmtuner.extras.callbacks import FixValueHeadModelCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.model import load_model_and_tokenizer
 from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
@@ -40,7 +40,7 @@ def run_rm(
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        callbacks=callbacks + [SavePeftModelCallback()],
+        callbacks=callbacks + [FixValueHeadModelCallback()],
         compute_metrics=compute_accuracy,
         **split_dataset(dataset, data_args, training_args)
     )
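Note (not part of the patch): a minimal sketch for checking the checkpoint layout that FixValueHeadModelCallback produces. The checkpoint_dir path is a placeholder; the file names come from the new constants (v_head.bin / v_head.safetensors) and the expected key names from the docstring (v_head.summary.*).

# Hypothetical sanity check: list the value-head tensors that
# _fix_valuehead_checkpoint splits out of a saved checkpoint directory.
import os
import torch
from safetensors import safe_open

checkpoint_dir = "saves/checkpoint-100"  # placeholder path, adjust to your output_dir

safe_path = os.path.join(checkpoint_dir, "v_head.safetensors")  # V_HEAD_SAFE_WEIGHTS_NAME
bin_path = os.path.join(checkpoint_dir, "v_head.bin")           # V_HEAD_WEIGHTS_NAME

if os.path.isfile(safe_path):
    with safe_open(safe_path, framework="pt", device="cpu") as f:
        keys = list(f.keys())
elif os.path.isfile(bin_path):
    keys = list(torch.load(bin_path, map_location="cpu").keys())
else:
    keys = []

# Expected keys look like "v_head.summary.weight" and "v_head.summary.bias".
print(keys)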