[v1] add LoRA/Freeze support and merge workflow (#10157)

This commit is contained in:
jiaqiw09
2026-02-12 13:02:09 +08:00
committed by GitHub
parent 184304b5b4
commit ab073f4c13
9 changed files with 577 additions and 12 deletions

View File

@@ -204,6 +204,16 @@ class BaseTrainer:
def save_model(self) -> None:
"""Save the model."""
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
model_to_save.save_pretrained(self.args.output_dir)
state_dict = None
if self.args.dist_config is not None and self.args.dist_config.name == "fsdp2":
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
state_dict = get_model_state_dict(self.model, options=options)
if DistributedInterface().get_rank() != 0:
return
model_to_save.save_pretrained(self.args.output_dir, state_dict=state_dict)
self.renderer.processor.save_pretrained(self.args.output_dir)
logger.info_rank0(f"Model saved to {self.args.output_dir}")