mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 03:10:35 +08:00
fix ppo in trl 0.8.6
This commit is contained in:
@@ -309,12 +309,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
)
|
||||
return lr_scheduler
|
||||
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
|
||||
super()._save(output_dir, state_dict)
|
||||
if self.processor is not None:
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]:
|
||||
r"""
|
||||
@@ -326,6 +320,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
batch[k] = v[:, start_index:]
|
||||
|
||||
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
||||
unwrapped_model = self.accelerator.unwrap_model(self.model) # issue in trl v0.8.6
|
||||
if self.model_args.upcast_layernorm:
|
||||
layernorm_params = dump_layernorm(unwrapped_model)
|
||||
|
||||
@@ -369,19 +364,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
return get_rewards_from_server(self.reward_model, messages)
|
||||
|
||||
batch = self.prepare_model_inputs(queries, responses)
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
|
||||
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
reward_model = self.model
|
||||
else:
|
||||
reward_model = self.reward_model
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
reward_model = self.model
|
||||
else:
|
||||
reward_model = self.reward_model
|
||||
|
||||
with self.amp_context: # support bf16
|
||||
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
|
||||
with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16
|
||||
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
|
||||
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
replace_model(unwrapped_model, target="default")
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
replace_model(unwrapped_model, target="default")
|
||||
|
||||
if self.is_chatglm_model: # assume same architecture
|
||||
values = torch.transpose(values, 0, 1)
|
||||
@@ -482,3 +477,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
self._save(output_dir, state_dict={})
|
||||
remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
|
||||
self.model.save_checkpoint(output_dir)
|
||||
|
||||
if self.processor is not None:
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||
|
||||
Reference in New Issue
Block a user