fix ppo in trl 0.8.6

This commit is contained in:
hiyouga
2024-06-07 04:48:29 +08:00
parent f9e818d79c
commit 2702d7e952
4 changed files with 37 additions and 25 deletions

View File

@@ -309,12 +309,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
)
return lr_scheduler
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
super()._save(output_dir, state_dict)
if self.processor is not None:
output_dir = output_dir if output_dir is not None else self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir)
@torch.no_grad()
def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]:
r"""
@@ -326,6 +320,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
batch[k] = v[:, start_index:]
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
unwrapped_model = self.accelerator.unwrap_model(self.model) # issue in trl v0.8.6
if self.model_args.upcast_layernorm:
layernorm_params = dump_layernorm(unwrapped_model)
@@ -369,19 +364,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
return get_rewards_from_server(self.reward_model, messages)
batch = self.prepare_model_inputs(queries, responses)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="reward")
reward_model = self.model
else:
reward_model = self.reward_model
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="reward")
reward_model = self.model
else:
reward_model = self.reward_model
with self.amp_context: # support bf16
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="default")
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="default")
if self.is_chatglm_model: # assume same architecture
values = torch.transpose(values, 0, 1)
@@ -482,3 +477,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self._save(output_dir, state_dict={})
remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
self.model.save_checkpoint(output_dir)
if self.processor is not None:
output_dir = output_dir if output_dir is not None else self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir)