fix ppo in trl 0.8.6

Former-commit-id: 2702d7e952523b584d67c8901888b492d4a79b14
This commit is contained in:
hiyouga 2024-06-07 04:48:29 +08:00
parent d3196318be
commit f76d427332
4 changed files with 37 additions and 25 deletions

View File

@ -298,7 +298,7 @@ huggingface-cli login
| datasets | 2.16.0 | 2.19.2 | | datasets | 2.16.0 | 2.19.2 |
| accelerate | 0.30.1 | 0.30.1 | | accelerate | 0.30.1 | 0.30.1 |
| peft | 0.11.1 | 0.11.1 | | peft | 0.11.1 | 0.11.1 |
| trl | 0.8.6 | 0.9.3 | | trl | 0.8.6 | 0.9.4 |
| Optional | Minimum | Recommend | | Optional | Minimum | Recommend |
| ------------ | ------- | --------- | | ------------ | ------- | --------- |

View File

@ -298,7 +298,7 @@ huggingface-cli login
| datasets | 2.16.0 | 2.19.2 | | datasets | 2.16.0 | 2.19.2 |
| accelerate | 0.30.1 | 0.30.1 | | accelerate | 0.30.1 | 0.30.1 |
| peft | 0.11.1 | 0.11.1 | | peft | 0.11.1 | 0.11.1 |
| trl | 0.8.6 | 0.9.3 | | trl | 0.8.6 | 0.9.4 |
| 可选项 | 至少 | 推荐 | | 可选项 | 至少 | 推荐 |
| ------------ | ------- | --------- | | ------------ | ------- | --------- |

View File

@ -1,7 +1,9 @@
import json import json
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Literal, Optional from typing import TYPE_CHECKING, Dict, List, Literal, Optional
import torch import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.packages import is_requests_available from ...extras.packages import is_requests_available
@ -28,16 +30,27 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None: def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
r""" r"""
Replaces the default/reward modules in the model. The model is already unwrapped (and gathered). Replaces the default/reward modules in the model. The model is already unwrapped.
""" """
if target == "reward": # save default head temporarily if is_deepspeed_zero3_enabled():
setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone()) import deepspeed # type: ignore
setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
params = [model.v_head.summary.weight, model.v_head.summary.bias]
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
device = model.v_head.summary.weight.device with context_maybe_zero3:
model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device) if target == "reward": # save default head temporarily
model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device) setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
device = model.v_head.summary.weight.device
model.v_head.summary.weight.data = (
model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
)
model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]: def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:

View File

@ -309,12 +309,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
) )
return lr_scheduler return lr_scheduler
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
super()._save(output_dir, state_dict)
if self.processor is not None:
output_dir = output_dir if output_dir is not None else self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir)
@torch.no_grad() @torch.no_grad()
def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]: def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]:
r""" r"""
@ -326,6 +320,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
batch[k] = v[:, start_index:] batch[k] = v[:, start_index:]
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
unwrapped_model = self.accelerator.unwrap_model(self.model) # issue in trl v0.8.6
if self.model_args.upcast_layernorm: if self.model_args.upcast_layernorm:
layernorm_params = dump_layernorm(unwrapped_model) layernorm_params = dump_layernorm(unwrapped_model)
@ -369,19 +364,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
return get_rewards_from_server(self.reward_model, messages) return get_rewards_from_server(self.reward_model, messages)
batch = self.prepare_model_inputs(queries, responses) batch = self.prepare_model_inputs(queries, responses)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: if self.finetuning_args.reward_model_type == "lora":
if self.finetuning_args.reward_model_type == "lora": replace_model(unwrapped_model, target="reward")
replace_model(unwrapped_model, target="reward") reward_model = self.model
reward_model = self.model else:
else: reward_model = self.reward_model
reward_model = self.reward_model
with self.amp_context: # support bf16 with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False) _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
if self.finetuning_args.reward_model_type == "lora": if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="default") replace_model(unwrapped_model, target="default")
if self.is_chatglm_model: # assume same architecture if self.is_chatglm_model: # assume same architecture
values = torch.transpose(values, 0, 1) values = torch.transpose(values, 0, 1)
@ -482,3 +477,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self._save(output_dir, state_dict={}) self._save(output_dir, state_dict={})
remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]) remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
self.model.save_checkpoint(output_dir) self.model.save_checkpoint(output_dir)
if self.processor is not None:
output_dir = output_dir if output_dir is not None else self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir)