From bb7bf51554d4ba8432333c35a5e3b52705955ede Mon Sep 17 00:00:00 2001
From: Yaowei Zheng
Date: Thu, 26 Jun 2025 13:55:42 +0800
Subject: [PATCH] Merge commit from fork

---
 src/llamafactory/model/model_utils/valuehead.py | 2 +-
 src/llamafactory/train/callbacks.py             | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/llamafactory/model/model_utils/valuehead.py b/src/llamafactory/model/model_utils/valuehead.py
index 137c6b7d..7409a22e 100644
--- a/src/llamafactory/model/model_utils/valuehead.py
+++ b/src/llamafactory/model/model_utils/valuehead.py
@@ -49,7 +49,7 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
 
     try:
         vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
-        return torch.load(vhead_file, map_location="cpu")
+        return torch.load(vhead_file, map_location="cpu", weights_only=True)
     except Exception as err:
         err_text = str(err)
 
diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py
index c1bd1599..3e351c0f 100644
--- a/src/llamafactory/train/callbacks.py
+++ b/src/llamafactory/train/callbacks.py
@@ -76,7 +76,7 @@ def fix_valuehead_checkpoint(
             state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
     else:
         path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
-        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu", weights_only=True)
 
     os.remove(path_to_checkpoint)
     decoder_state_dict, v_head_state_dict = {}, {}