mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 20:52:59 +08:00
fix ppo trainer #551
Former-commit-id: 0676497104eccc8a737d27890eabf1ca8713c235
This commit is contained in:
parent 8f9f618bcc
commit 570ccc3618
@@ -12,7 +12,7 @@ from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
 from llmtuner.tuner.core.trainer import PeftTrainer
-from llmtuner.tuner.ppo.utils import replace_model
+from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments
@@ -152,8 +152,10 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         if length_sampler is not None:
             generation_kwargs["max_new_tokens"] = length_sampler()
 
+        self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
+        self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)
 
         # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
         # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
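
Taken together, the two added lines bracket generation: the 1-D LayerNorm parameters, kept in float32 for stability (see the inline comments in the new helper below), are temporarily cast to the model's compute dtype before generate() and restored afterwards, presumably so every tensor shares one dtype during decoding. A commented restatement of the hunk (same names as the diff; only the comments are added here):

    # cast LayerNorm params to the compute dtype (e.g. float16) and keep a float32 backup
    self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
    unwrapped_model = self.accelerator.unwrap_model(self.model)
    response = unwrapped_model.generate(**batch, **generation_kwargs)
    # put the original float32 LayerNorm weights back before training continues
    self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)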
@@ -1,4 +1,7 @@
-from typing import TYPE_CHECKING, Literal
+import torch
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
+
+from llmtuner.extras.constants import LAYERNORM_NAMES
 
 if TYPE_CHECKING:
     from trl import AutoModelForCausalLMWithValueHead
@@ -15,3 +18,23 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
         "summary.weight": getattr(model, "{}_head_weight".format(target)),
         "summary.bias": getattr(model, "{}_head_bias".format(target))
     })
+
+
+def cast_layernorm_dtype(
+    model: "AutoModelForCausalLMWithValueHead",
+    compute_dtype: torch.dtype,
+    layer_norm_params: Optional[Dict[str, torch.Tensor]] = None,
+    layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
+) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]:
+
+    layer_norm_state_dict = {}
+
+    for name, param in model.named_parameters():
+        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
+            if layer_norm_params is None:
+                layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
+                param.data = param.data.to(compute_dtype)
+            else:
+                param.data = layer_norm_params[name] # restore float32 weights
+
+    return model, layer_norm_state_dict
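
As a quick sanity check of the helper's round trip outside the trainer, here is a minimal sketch on a toy module (the Toy class, the explicit layer_norm_names=["norm"], and float16 as the compute dtype are illustrative assumptions, not part of the commit; it assumes llmtuner is importable):

    import torch
    from torch import nn

    from llmtuner.tuner.ppo.utils import cast_layernorm_dtype


    class Toy(nn.Module):
        # Stand-in for the value-head model; only named_parameters() is exercised.
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)
            self.norm = nn.LayerNorm(8)


    model = Toy()

    # First call: cast the 1-D "norm" parameters to float16 and keep a float32 backup.
    model, backup = cast_layernorm_dtype(model, torch.float16, layer_norm_names=["norm"])
    print(model.norm.weight.dtype)  # torch.float16

    # ... half-precision generation would happen here ...

    # Second call: pass the backup to restore the original float32 LayerNorm weights.
    model, _ = cast_layernorm_dtype(model, torch.float16, layer_norm_params=backup, layer_norm_names=["norm"])
    print(model.norm.weight.dtype)  # torch.float32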