mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-03 20:22:49 +08:00
parent
b983de9f4f
commit
0b7e870b07
@ -97,6 +97,10 @@ class ModelArguments:
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
|
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
|
||||||
)
|
)
|
||||||
|
offload_folder: str = field(
|
||||||
|
default="offload",
|
||||||
|
metadata={"help": "Path to offload model weights."},
|
||||||
|
)
|
||||||
hf_hub_token: Optional[str] = field(
|
hf_hub_token: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Auth token to log in with Hugging Face Hub."},
|
metadata={"help": "Auth token to log in with Hugging Face Hub."},
|
||||||
|
@ -291,6 +291,9 @@ def patch_config(
|
|||||||
if "device_map" not in init_kwargs: # quant models cannot use auto device map
|
if "device_map" not in init_kwargs: # quant models cannot use auto device map
|
||||||
init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
|
init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
|
||||||
|
|
||||||
|
if init_kwargs["device_map"] == "auto":
|
||||||
|
init_kwargs["offload_folder"] = model_args.offload_folder
|
||||||
|
|
||||||
|
|
||||||
def patch_model(
|
def patch_model(
|
||||||
model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
|
model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
|
||||||
|
Loading…
x
Reference in New Issue
Block a user