Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-09-20 12:02:48 +08:00
Compare commits
No commits in common. "b2395b25b007898f994d8bed101e0d6263de51b8" and "d8e4482849e3c0ea2df30e7d2b2013268b693053" have entirely different histories.
b2395b25b0 ... d8e4482849
@@ -19,7 +19,6 @@ sentencepiece
 tiktoken
 modelscope>=1.14.0
 hf-transfer
-safetensors<=0.5.3
 # python
 fire
 omegaconf
@@ -35,7 +35,7 @@ class DataArguments:
         default=None,
         metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
     )
-    dataset_dir: str = field(
+    dataset_dir: Union[str, dict] = field(
         default="data",
         metadata={"help": "Path to the folder containing the datasets."},
     )
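The only change in this hunk widens the type of dataset_dir from str to Union[str, dict], which suggests a loader on this side can accept an inline dataset registry as well as a folder path. A minimal sketch of how a caller might branch on that type; the helper name load_dataset_info and the dataset_info.json file name are assumptions for illustration, not taken from this diff:

import json
import os
from typing import Union


def load_dataset_info(dataset_dir: Union[str, dict]) -> dict:
    """Hypothetical helper: resolve dataset metadata from a folder or a mapping."""
    if isinstance(dataset_dir, dict):
        # Inline mapping: use it directly, no filesystem access needed.
        return dataset_dir
    # Folder path: read the dataset registry file from that directory
    # (file name assumed here for the sketch).
    with open(os.path.join(dataset_dir, "dataset_info.json"), encoding="utf-8") as f:
        return json.load(f)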
@@ -73,7 +73,7 @@ def fix_valuehead_checkpoint(
     if safe_serialization:
         path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
         with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
-            state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key).clone() for key in f.keys()}
+            state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
     else:
         path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
         state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu", weights_only=True)
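The only change here drops .clone() from the safe_open comprehension. A minimal sketch of the two variants side by side, assuming the clone is there to give each tensor storage independent of whatever backs the open checkpoint file in a given safetensors version (possibly related to the safetensors<=0.5.3 pin removed above); the file name and tensor key below are placeholders:

import torch
from safetensors import safe_open
from safetensors.torch import save_file

path = "value_head.safetensors"  # placeholder path
save_file({"v_head.weight": torch.zeros(1, 8)}, path)  # toy checkpoint for the demo

with safe_open(path, framework="pt", device="cpu") as f:
    # Variant on the left side of the diff: clone each tensor so the
    # resulting state_dict owns a fresh copy of its data.
    cloned = {key: f.get_tensor(key).clone() for key in f.keys()}
    # Variant on the right side: keep the tensors exactly as returned.
    plain = {key: f.get_tensor(key) for key in f.keys()}

print(cloned["v_head.weight"].shape, plain["v_head.weight"].shape)

Either comprehension yields a plain dict[str, torch.Tensor]; the variants differ only in whether each tensor is detached from the source buffer by an explicit copy.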