Improve the LoRA+ implementation.

This commit is contained in:
hiyouga
2024-03-13 23:32:51 +08:00
parent 4e5e99af43
commit 72367307df
12 changed files with 165 additions and 169 deletions

View File

@@ -12,7 +12,7 @@ from ...model import load_model, load_tokenizer
from ...train.sft.metric import ComputeMetrics
from ...train.sft.trainer import CustomSeq2SeqTrainer
from ...train.utils import create_modelcard_and_push
from ..utils import create_custom_optimzer, create_lora_plus_optimizer
from ..utils import create_custom_optimzer
if TYPE_CHECKING:
@@ -51,8 +51,6 @@ def run_sft(
# Initialize our Trainer
optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
if finetuning_args.lora_lr_ratio:
optimizer = create_lora_plus_optimizer(model, training_args, finetuning_args)
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,