diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py
index 2139f5db..17ab5dc1 100644
--- a/src/llmtuner/extras/callbacks.py
+++ b/src/llmtuner/extras/callbacks.py
@@ -1,78 +1,23 @@
 import os
 import json
 import time
-import torch
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
 from datetime import timedelta
-
-from transformers import PreTrainedModel, TrainerCallback
-from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
+from transformers import TrainerCallback
 from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
-from peft import PeftModel
 
-from llmtuner.extras.constants import LOG_FILE_NAME, V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
+from llmtuner.extras.constants import LOG_FILE_NAME
 from llmtuner.extras.logging import get_logger
+from llmtuner.extras.misc import fix_valuehead_checkpoint
 
 if TYPE_CHECKING:
     from transformers import TrainingArguments, TrainerState, TrainerControl
-    from trl import AutoModelForCausalLMWithValueHead
 
 
 logger = get_logger(__name__)
 
 
-def _fix_valuehead_checkpoint(
-    model: "AutoModelForCausalLMWithValueHead",
-    output_dir: str,
-    safe_serialization: bool
-) -> None:
-    r"""
-    The model is already unwrapped.
-
-    There are three cases:
-    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
-    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
-    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
-
-    We assume `stage3_gather_16bit_weights_on_model_save=true`.
-    """
-    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
-        return
-
-    if safe_serialization:
-        from safetensors import safe_open
-        from safetensors.torch import save_file
-        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
-        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
-            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
-    else:
-        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
-        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
-
-    decoder_state_dict = {}
-    v_head_state_dict = {}
-    for name, param in state_dict.items():
-        if name.startswith("v_head."):
-            v_head_state_dict[name] = param
-        else:
-            decoder_state_dict[name.replace("pretrained_model.", "")] = param
-
-    os.remove(path_to_checkpoint)
-    model.pretrained_model.save_pretrained(
-        output_dir,
-        state_dict=decoder_state_dict or None,
-        safe_serialization=safe_serialization
-    )
-
-    if safe_serialization:
-        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
-    else:
-        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
-
-    logger.info("Value head model saved at: {}".format(output_dir))
-
-
 class FixValueHeadModelCallback(TrainerCallback):
 
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
@@ -80,21 +25,12 @@ class FixValueHeadModelCallback(TrainerCallback):
         Event called after a checkpoint save.
         """
         if args.should_save:
-            _fix_valuehead_checkpoint(
+            fix_valuehead_checkpoint(
                 model=kwargs.pop("model"),
                 output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
                 safe_serialization=args.save_safetensors
             )
 
-    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of training.
-        """
-        if args.should_save:
-            _fix_valuehead_checkpoint(
-                model=kwargs.pop("model"), output_dir=args.output_dir, safe_serialization=args.save_safetensors
-            )
-
 
 class LogCallback(TrainerCallback):
 
diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index 3455b5e2..dee101ec 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -1,14 +1,21 @@
 import gc
 import os
 import torch
-from typing import TYPE_CHECKING, Tuple
-from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
+from typing import TYPE_CHECKING, Dict, Tuple
+from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
 from transformers.utils import (
+    WEIGHTS_NAME,
+    SAFE_WEIGHTS_NAME,
     is_torch_bf16_gpu_available,
     is_torch_cuda_available,
     is_torch_npu_available,
     is_torch_xpu_available
 )
+from peft import PeftModel
+
+from llmtuner.extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
+from llmtuner.extras.logging import get_logger
+
 
 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
 try:
@@ -18,9 +25,13 @@ except:
 
 
 if TYPE_CHECKING:
+    from trl import AutoModelForCausalLMWithValueHead
     from llmtuner.hparams import ModelArguments
 
 
+logger = get_logger(__name__)
+
+
 class AverageMeter:
     r"""
     Computes and stores the average and current value.
@@ -63,6 +74,57 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     return trainable_params, all_param
 
 
+def fix_valuehead_checkpoint(
+    model: "AutoModelForCausalLMWithValueHead",
+    output_dir: str,
+    safe_serialization: bool
+) -> None:
+    r"""
+    The model is already unwrapped.
+
+    There are three cases:
+    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
+    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
+    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
+
+    We assume `stage3_gather_16bit_weights_on_model_save=true`.
+    """
+    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
+        return
+
+    if safe_serialization:
+        from safetensors import safe_open
+        from safetensors.torch import save_file
+        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
+        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
+            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
+    else:
+        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
+        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+
+    decoder_state_dict = {}
+    v_head_state_dict = {}
+    for name, param in state_dict.items():
+        if name.startswith("v_head."):
+            v_head_state_dict[name] = param
+        else:
+            decoder_state_dict[name.replace("pretrained_model.", "")] = param
+
+    os.remove(path_to_checkpoint)
+    model.pretrained_model.save_pretrained(
+        output_dir,
+        state_dict=decoder_state_dict or None,
+        safe_serialization=safe_serialization
+    )
+
+    if safe_serialization:
+        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
+    else:
+        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
+
+    logger.info("Value head model saved at: {}".format(output_dir))
+
+
 def get_current_device() -> torch.device:
     r"""
     Gets the current available device.
diff --git a/src/llmtuner/train/ppo/workflow.py b/src/llmtuner/train/ppo/workflow.py
index 1cbbd8da..10c6a227 100644
--- a/src/llmtuner/train/ppo/workflow.py
+++ b/src/llmtuner/train/ppo/workflow.py
@@ -9,6 +9,7 @@ from transformers.optimization import get_scheduler
 
 from llmtuner.data import get_dataset, preprocess_dataset
 from llmtuner.extras.callbacks import FixValueHeadModelCallback
+from llmtuner.extras.misc import fix_valuehead_checkpoint
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.model import load_model_and_tokenizer
 from llmtuner.train.utils import create_ref_model, create_reward_model
@@ -95,6 +96,8 @@ def run_ppo(
     if training_args.do_train:
         ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
         ppo_trainer.save_model()
+        if training_args.should_save:
+            fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
         ppo_trainer.save_state()  # must be called after save_model to have a folder
         if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "reward"])
diff --git a/src/llmtuner/train/rm/workflow.py b/src/llmtuner/train/rm/workflow.py
index f720aaf2..52070027 100644
--- a/src/llmtuner/train/rm/workflow.py
+++ b/src/llmtuner/train/rm/workflow.py
@@ -5,6 +5,7 @@ from transformers import Seq2SeqTrainingArguments
 
 from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import FixValueHeadModelCallback
+from llmtuner.extras.misc import fix_valuehead_checkpoint
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.model import load_model_and_tokenizer
 from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
@@ -49,6 +50,8 @@ def run_rm(
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
         trainer.save_model()
+        if training_args.should_save:
+            fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()