mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-03 04:02:49 +08:00
huggingface login for projects must login while running
Former-commit-id: 3e000c2b60c2e29bcafcf8d39c1a5d567ae2491c
This commit is contained in:
parent
0dc9b41b16
commit
7ffd961b8b
@ -16,3 +16,4 @@ pydantic==1.10.11
|
|||||||
fastapi==0.95.1
|
fastapi==0.95.1
|
||||||
sse-starlette
|
sse-starlette
|
||||||
matplotlib
|
matplotlib
|
||||||
|
huggingface_hub
|
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from huggingface_hub.hf_api import HfFolder
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -63,6 +64,11 @@ class ModelArguments:
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
||||||
)
|
)
|
||||||
|
hf_hub_token : Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.checkpoint_dir is not None: # support merging multiple lora weights
|
if self.checkpoint_dir is not None: # support merging multiple lora weights
|
||||||
@ -70,3 +76,6 @@ class ModelArguments:
|
|||||||
|
|
||||||
if self.quantization_bit is not None:
|
if self.quantization_bit is not None:
|
||||||
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
|
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
|
||||||
|
|
||||||
|
if self.use_auth_token == True and self.hf_hub_token != None:
|
||||||
|
HfFolder.save_token(self.hf_hub_token)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user