mirror of https://github.com/hiyouga/LLaMA-Factory.git
fix bug in freeze tuning
Former-commit-id: ff52b1779c909819d0aef83d3f7ea663199cbe54
parent 627212e48b
commit 0ed0b8f9c5
@@ -37,7 +37,13 @@ def init_adapter(
     if finetuning_args.finetuning_type == "freeze" and is_trainable:
         logger.info("Fine-tuning method: Freeze")
-        num_layers = getattr(model.config, "num_layers")
+        num_layers = (
+            getattr(model.config, "num_hidden_layers", None)
+            or getattr(model.config, "num_layers", None)
+            or getattr(model.config, "n_layer", None)
+        )
+        if not num_layers:
+            raise ValueError("Current model does not support freeze tuning.")
         if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
             trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
         else: # fine-tuning the first n layers if num_layer_trainable < 0
             trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
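The bug here is that the old two-argument getattr(model.config, "num_layers") raises AttributeError on any model whose config names the attribute differently, which is most of them. The fix probes the three common names with a None default and falls through. A minimal, self-contained sketch of the same probing logic; the SimpleNamespace objects below are hypothetical stand-ins for real Hugging Face configs (LLaMA-style configs expose num_hidden_layers, ChatGLM uses num_layers, GPT-2/BLOOM-style configs use n_layer):

    from types import SimpleNamespace

    def get_num_layers(config):
        # Probe the attribute names used by different model families;
        # getattr(..., None) avoids the AttributeError the old code hit.
        num_layers = (
            getattr(config, "num_hidden_layers", None)  # LLaMA-style configs
            or getattr(config, "num_layers", None)      # ChatGLM-style configs
            or getattr(config, "n_layer", None)         # GPT-2/BLOOM-style configs
        )
        if not num_layers:
            raise ValueError("Current model does not support freeze tuning.")
        return num_layers

    # Hypothetical stand-ins for real model configs:
    assert get_num_layers(SimpleNamespace(num_hidden_layers=32)) == 32
    assert get_num_layers(SimpleNamespace(num_layers=28)) == 28
    assert get_num_layers(SimpleNamespace(n_layer=12)) == 12

Note that or-chaining also treats an explicit 0 as missing; that is harmless here, since a zero-layer model cannot be freeze-tuned anyway, and the if not num_layers guard turns it into the same ValueError.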
@@ -76,4 +76,5 @@ def create_reward_model(
     reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
     reward_model, _ = load_model_and_tokenizer(reward_model_args, reward_finetuning_args, is_trainable=False, stage="ppo")
     logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
+    logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
     return reward_model
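The added warning flags a real failure mode: during PPO the reward model scores token ids produced by the policy model, so the two must share a tokenizer and vocabulary or the scores are meaningless. A stricter setup could fail fast instead of warning. This is a sketch under that assumption, not part of the commit; it relies on the standard get_vocab() method of Hugging Face tokenizers, and the helper name is ours:

    def assert_shared_vocab(ppo_tokenizer, reward_tokenizer):
        # Compare the full token -> id mappings; identical vocabularies
        # guarantee the reward model scores the same tokens the policy emits.
        if ppo_tokenizer.get_vocab() != reward_tokenizer.get_vocab():
            raise ValueError(
                "PPO model and reward model must share the same tokenizer and vocabulary."
            )

    # Usage (tokenizers would be loaded elsewhere, e.g. via AutoTokenizer.from_pretrained):
    # assert_shared_vocab(ppo_tokenizer, reward_tokenizer)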