[FEATURE]: ADD LORA+ ALGORITHM

This commit is contained in:
齐保元
2024-03-13 19:43:27 +08:00
parent dfd451b722
commit a0965cd62c
4 changed files with 130 additions and 3 deletions

View File

@@ -12,7 +12,7 @@ from ...model import load_model, load_tokenizer
from ...train.sft.metric import ComputeMetrics
from ...train.sft.trainer import CustomSeq2SeqTrainer
from ...train.utils import create_modelcard_and_push
from ..utils import create_custom_optimzer
from ..utils import create_custom_optimzer, create_lora_plus_optimizer
if TYPE_CHECKING:
@@ -51,6 +51,8 @@ def run_sft(
# Initialize our Trainer
optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
if finetuning_args.lora_lr_ratio:
optimizer = create_lora_plus_optimizer(model, training_args, finetuning_args)
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,