mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-02-05 21:42:20 +08:00
[trainer] add dpo/kto fsdp fsdp2 support (#10127)
This commit is contained in:
@@ -25,8 +25,8 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import Trainer
|
||||
from trl import DPOTrainer
|
||||
from trl.models.utils import prepare_deepspeed, prepare_fsdp
|
||||
from trl.trainer import disable_dropout_in_model
|
||||
from trl.trainer.utils import prepare_deepspeed
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
@@ -97,6 +97,13 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
|
||||
): # quantized models are already set on the correct device
|
||||
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||
elif self.is_fsdp_enabled:
|
||||
if self.accelerator.is_fsdp2:
|
||||
from accelerate.utils.fsdp_utils import fsdp2_prepare_model
|
||||
|
||||
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model)
|
||||
else:
|
||||
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
self.ref_model.eval()
|
||||
|
||||
@@ -24,8 +24,8 @@ from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
from trl import KTOTrainer
|
||||
from trl.models.utils import prepare_deepspeed, prepare_fsdp
|
||||
from trl.trainer import disable_dropout_in_model
|
||||
from trl.trainer.utils import prepare_deepspeed
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
@@ -99,6 +99,13 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
|
||||
): # quantized models are already set on the correct device
|
||||
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||
elif self.is_fsdp_enabled:
|
||||
if self.accelerator.is_fsdp2:
|
||||
from accelerate.utils.fsdp_utils import fsdp2_prepare_model
|
||||
|
||||
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model)
|
||||
else:
|
||||
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
self.ref_model.eval()
|
||||
|
||||
Reference in New Issue
Block a user