diff --git a/src/llmtuner/train/dpo/trainer.py b/src/llmtuner/train/dpo/trainer.py index 132df189..97d80353 100644 --- a/src/llmtuner/train/dpo/trainer.py +++ b/src/llmtuner/train/dpo/trainer.py @@ -36,6 +36,7 @@ class CustomDPOTrainer(DPOTrainer): self.precompute_ref_log_probs = False self._precomputed_train_ref_log_probs = False self._precomputed_eval_ref_log_probs = False + self._peft_has_been_casted_to_bf16 = False self.ref_model = ref_model self.beta = beta