From ad0304e147b63ffef3e6c3877a61970174658d43 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Tue, 25 Jun 2024 02:31:44 +0800 Subject: [PATCH] fix #4379 Former-commit-id: cc016461e63a570142b56d50a5d11e55a96ab8db --- src/llamafactory/train/tuner.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 788b4c4f..a02fff22 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import shutil from typing import TYPE_CHECKING, Any, Dict, List, Optional import torch @@ -19,6 +21,7 @@ from transformers import PreTrainedModel from ..data import get_template_and_fix_tokenizer from ..extras.callbacks import LogCallback +from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ..extras.logging import get_logger from ..hparams import get_infer_args, get_train_args from ..model import load_model, load_tokenizer @@ -98,6 +101,25 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None: safe_serialization=(not model_args.export_legacy_format), ) + if finetuning_args.stage == "rm": + if model_args.adapter_name_or_path is not None: + vhead_path = model_args.adapter_name_or_path[-1] + else: + vhead_path = model_args.model_name_or_path + + if os.path.exists(os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME)): + shutil.copy( + os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME), + os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME), + ) + logger.info("Copied valuehead to {}.".format(model_args.export_dir)) + elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)): + shutil.copy( + os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME), + os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME), + ) + logger.info("Copied valuehead to {}.".format(model_args.export_dir)) + try: tokenizer.padding_side = "left" # restore padding side tokenizer.init_kwargs["padding_side"] = "left"