remove gc warnings in DPO&KTO

Former-commit-id: f9a206509e
This commit is contained in:
hiyouga
2024-06-03 22:53:54 +08:00
parent b2c224de69
commit 6f7b6ae0c3
3 changed files with 20 additions and 6 deletions

View File

@@ -9,7 +9,7 @@ from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
from ..utils import create_custom_optimzer, create_custom_scheduler
from ..utils import create_custom_optimzer, create_custom_scheduler, get_ref_context
if TYPE_CHECKING:
@@ -68,6 +68,7 @@ class CustomKTOTrainer(KTOTrainer):
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval()
if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor
@@ -164,7 +165,7 @@ class CustomKTOTrainer(KTOTrainer):
"""
if self.ref_model is None:
ref_model = model
ref_context = self.accelerator.unwrap_model(model).disable_adapter()
ref_context = get_ref_context(self.accelerator, model)
else:
ref_model = self.ref_model
ref_context = nullcontext()