This commit is contained in:
hiyouga
2023-09-08 20:22:18 +08:00
parent 8ea32e4046
commit b34797a845
3 changed files with 5 additions and 3 deletions

View File

@@ -175,6 +175,7 @@ def load_model_and_tokenizer(
# Initialize adapters
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
model = model.train() if is_trainable else model.eval()
# Prepare model with valuehead for RLHF
if stage == "rm" or stage == "ppo":