Mirror of https://github.com/hiyouga/LLaMA-Factory.git (last synced 2026-06-21 22:58:58 +08:00)
[feat] support ktransformers for dpo (#9621)
Co-authored-by: poryfly <porykid@gmail.com>
This commit is contained in:
@@ -24,7 +24,6 @@ from ...extras.ploting import plot_loss
|
||||
from ...hparams import ModelArguments
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..trainer_utils import create_modelcard_and_push, create_ref_model
|
||||
from .trainer import CustomDPOTrainer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -62,6 +61,16 @@ def run_dpo(
|
||||
ref_model = create_ref_model(model_args, finetuning_args)
|
||||
else:
|
||||
ref_model = None
|
||||
|
||||
|
||||
if model_args.use_kt:
|
||||
from ktransformers.util.globals import GLOBAL_CONFIG
|
||||
|
||||
GLOBAL_CONFIG._config["mod"] = "sft"
|
||||
|
||||
from .ktrainer import CustomDPOTrainer
|
||||
else:
|
||||
from .trainer import CustomDPOTrainer
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = CustomDPOTrainer(
|
||||
|
||||
Reference in New Issue
Block a user