mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 11:12:50 +08:00
Merge commit from fork
This commit is contained in:
parent
7242caf0ff
commit
bb7bf51554
@@ -49,7 +49,7 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
 
     try:
         vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
-        return torch.load(vhead_file, map_location="cpu")
+        return torch.load(vhead_file, map_location="cpu", weights_only=True)
     except Exception as err:
         err_text = str(err)
 
@@ -76,7 +76,7 @@ def fix_valuehead_checkpoint(
             state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
     else:
         path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
-        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu", weights_only=True)
 
     os.remove(path_to_checkpoint)
     decoder_state_dict, v_head_state_dict = {}, {}
Loading…
x
Reference in New Issue
Block a user