mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-04-28 02:39:03 +08:00
[v1] add LoRA/Freeze support and merge workflow (#10157)
This commit is contained in:
@@ -204,6 +204,16 @@ class BaseTrainer:
|
||||
def save_model(self) -> None:
    """Save the model and the processor to ``self.args.output_dir``.

    When the distributed config is named ``"fsdp2"``, the state dict is
    collected through ``get_model_state_dict`` with ``full_state_dict=True``
    and ``cpu_offload=True`` and handed to ``save_pretrained`` explicitly;
    in that mode every rank other than rank 0 returns before anything is
    written. In all other cases ``state_dict`` stays ``None`` and
    ``save_pretrained`` is called without an explicit state dict.
    """
    # Unwrap a wrapper object that exposes the real model as ``.module``.
    model_to_save = self.model.module if hasattr(self.model, "module") else self.model

    state_dict = None
    if self.args.dist_config is not None and self.args.dist_config.name == "fsdp2":
        # Imported lazily so the dependency is only touched on the FSDP2 path.
        from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict

        # Runs on every rank, before the rank guard below.
        # NOTE(review): presumably a collective call that all ranks must enter -- confirm.
        options = StateDictOptions(full_state_dict=True, cpu_offload=True)
        state_dict = get_model_state_dict(self.model, options=options)

        # Only rank 0 writes files in FSDP2 mode.
        # NOTE(review): nesting of this guard inside the fsdp2 branch is assumed;
        # the original indentation was not recoverable -- confirm.
        if DistributedInterface().get_rank() != 0:
            return

    # Single save: an earlier unconditional ``save_pretrained`` call (made before
    # the state dict was collected and before the rank guard) has been removed,
    # since it duplicated this write and ran on every rank.
    model_to_save.save_pretrained(self.args.output_dir, state_dict=state_dict)
    self.renderer.processor.save_pretrained(self.args.output_dir)
    logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
||||
|
||||
Reference in New Issue
Block a user