mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-07-31 10:42:50 +08:00
Merge commit from fork
This commit is contained in:
parent
7242caf0ff
commit
bb7bf51554
@ -49,7 +49,7 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
|
||||
|
||||
try:
|
||||
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
|
||||
return torch.load(vhead_file, map_location="cpu")
|
||||
return torch.load(vhead_file, map_location="cpu", weights_only=True)
|
||||
except Exception as err:
|
||||
err_text = str(err)
|
||||
|
||||
|
@ -76,7 +76,7 @@ def fix_valuehead_checkpoint(
|
||||
state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
|
||||
else:
|
||||
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
|
||||
state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
|
||||
state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu", weights_only=True)
|
||||
|
||||
os.remove(path_to_checkpoint)
|
||||
decoder_state_dict, v_head_state_dict = {}, {}
|
||||
|
Loading…
x
Reference in New Issue
Block a user