diff --git a/src/llamafactory/model/model_utils/valuehead.py b/src/llamafactory/model/model_utils/valuehead.py
index 137c6b7d..7409a22e 100644
--- a/src/llamafactory/model/model_utils/valuehead.py
+++ b/src/llamafactory/model/model_utils/valuehead.py
@@ -49,7 +49,7 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> dict[str, torch.Tensor]:
 
     try:
         vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
-        return torch.load(vhead_file, map_location="cpu")
+        return torch.load(vhead_file, map_location="cpu", weights_only=True)
     except Exception as err:
         err_text = str(err)
 
diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py
index c1bd1599..3e351c0f 100644
--- a/src/llamafactory/train/callbacks.py
+++ b/src/llamafactory/train/callbacks.py
@@ -76,7 +76,7 @@ def fix_valuehead_checkpoint(
             state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
     else:
         path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
-        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu", weights_only=True)
     os.remove(path_to_checkpoint)
 
     decoder_state_dict, v_head_state_dict = {}, {}
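
Note on the change: both call sites load a plain tensor state dict, so passing `weights_only=True` lets `torch.load` use its restricted unpickler (tensors and other allowlisted types only) instead of full `pickle`, which avoids arbitrary code execution from a crafted checkpoint file and silences the corresponding PyTorch warning. A minimal sketch of the behaviour being relied on, outside the diff; the file name and tensor keys below are illustrative, not taken from the repository:

```python
import torch

# Save a plain tensor state dict, similar to what the value-head checkpoint contains.
state_dict = {
    "v_head.summary.weight": torch.zeros(1, 4),
    "v_head.summary.bias": torch.zeros(1),
}
torch.save(state_dict, "value_head_demo.bin")

# weights_only=True restricts unpickling to tensors and allowlisted container types,
# so loading an untrusted checkpoint cannot trigger arbitrary code execution.
loaded = torch.load("value_head_demo.bin", map_location="cpu", weights_only=True)
print(loaded["v_head.summary.weight"].shape)  # torch.Size([1, 4])
```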