diff --git a/src/utils/peft_trainer.py b/src/utils/peft_trainer.py
index 4a22d5af..94dd11a6 100644
--- a/src/utils/peft_trainer.py
+++ b/src/utils/peft_trainer.py
@@ -105,11 +105,13 @@ class PeftTrainer(Seq2SeqTrainer):
         if self.finetuning_args.finetuning_type == "lora":
             backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
         else: # freeze/full tuning
+            backbone_model.config.use_cache = True
             backbone_model.save_pretrained(
                 output_dir,
                 state_dict=get_state_dict(backbone_model),
                 safe_serialization=self.args.save_safetensors
             )
+            backbone_model.config.use_cache = False
 
         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
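
For context: gradient checkpointing in transformers requires use_cache=False during training, so the trainer keeps the flag off and only flips it on around save_pretrained(), so that the exported config.json carries "use_cache": true for inference, then restores the training-time value. Below is a minimal sketch of the same pattern, assuming a standard Hugging Face PreTrainedModel; the "gpt2" checkpoint and "output_dir" path are illustrative stand-ins, not values from this diff.

# Minimal sketch of the save pattern in the diff above, assuming a
# Hugging Face PreTrainedModel. "gpt2" and "output_dir" are stand-ins.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Gradient checkpointing is incompatible with the KV cache, so the flag
# stays off for the whole training run.
model.gradient_checkpointing_enable()
model.config.use_cache = False

# ... training loop ...

# Flip the flag on only for serialization, so the exported config.json
# contains "use_cache": true and inference users get KV caching by default.
model.config.use_cache = True
model.save_pretrained("output_dir")

# Restore the training-time setting in case training continues.
model.config.use_cache = False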