Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-02 03:32:50 +08:00
Merge pull request #6364 from hiyouga/hiyouga/control_reenterent_gc
[model] support non-reenterent-gc Former-commit-id: a665ad6178516faf8aaa628d3b2c672ad831d7b7
This commit is contained in: commit fc18db6290
@@ -237,6 +237,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         default=False,
         metadata={"help": "Whether or not to disable gradient checkpointing."},
     )
+    use_reentrant_gc: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use reentrant gradient checkpointing."},
+    )
     upcast_layernorm: bool = field(
         default=False,
         metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
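For context, a minimal sketch of how a dataclass field like this becomes a command-line flag, assuming transformers' HfArgumentParser is used for argument parsing; the DemoArgs class name and the example invocation are illustrative only and not part of this commit:

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DemoArgs:
    # mirrors the field added in the hunk above
    use_reentrant_gc: bool = field(
        default=True,
        metadata={"help": "Whether or not to use reentrant gradient checkpointing."},
    )


parser = HfArgumentParser(DemoArgs)
# passing `--use_reentrant_gc false` on the command line overrides the default of True
(args,) = parser.parse_args_into_dataclasses(args=["--use_reentrant_gc", "false"])
print(args.use_reentrant_gc)  # False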
@@ -156,7 +156,9 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
             _gradient_checkpointing_enable, use_unsloth_gc=model_args.use_unsloth_gc
         )
         model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
-        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
+        model.gradient_checkpointing_enable(
+            gradient_checkpointing_kwargs={"use_reentrant": model_args.use_reentrant_gc}
+        )
         setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
         logger.info_rank0("Gradient checkpointing enabled.")
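A minimal sketch of what the new flag controls, written against the public transformers API (gradient_checkpointing_enable accepts gradient_checkpointing_kwargs in recent transformers versions); the model name is illustrative only and the False value corresponds to launching with use_reentrant_gc disabled:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative model only
# use_reentrant=False selects PyTorch's non-reentrant checkpoint implementation;
# transformers forwards this kwarg to torch.utils.checkpoint.checkpoint during the forward pass.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.config.use_cache = False  # cache is turned off while checkpointing, as in the hunk above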