PPO: support remote reward model (RM) server

This commit is contained in:
hiyouga
2023-12-03 21:38:51 +08:00
parent 7df4f3ab20
commit 747db40172
5 changed files with 47 additions and 15 deletions

View File

@@ -76,7 +76,9 @@ def create_reward_model(
Creates reward model for PPO training.
"""
if finetuning_args.reward_model_type == "api":
raise NotImplementedError
assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
logger.info("Use reward server {}".format(finetuning_args.reward_model))
return finetuning_args.reward_model
elif finetuning_args.reward_model_type == "lora":
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
@@ -102,6 +104,6 @@ def create_reward_model(
reward_model, _ = load_model_and_tokenizer(
reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
)
logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
return reward_model