mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 15:52:49 +08:00
Merge commit from fork
This commit is contained in:
parent
c1a7f2ebb2
commit
091d2539e8
@ -32,7 +32,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
|
|||||||
baichuan2_state_dict: dict[str, torch.Tensor] = OrderedDict()
|
baichuan2_state_dict: dict[str, torch.Tensor] = OrderedDict()
|
||||||
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
|
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
|
||||||
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
|
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
|
||||||
shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
|
shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu", weights_only=True)
|
||||||
baichuan2_state_dict.update(shard_weight)
|
baichuan2_state_dict.update(shard_weight)
|
||||||
|
|
||||||
llama_state_dict: dict[str, torch.Tensor] = OrderedDict()
|
llama_state_dict: dict[str, torch.Tensor] = OrderedDict()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user