mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
fix ppo in trl 0.8.6
Former-commit-id: 2702d7e952523b584d67c8901888b492d4a79b14
This commit is contained in:
parent
d3196318be
commit
f76d427332
@ -298,7 +298,7 @@ huggingface-cli login
|
|||||||
| datasets | 2.16.0 | 2.19.2 |
|
| datasets | 2.16.0 | 2.19.2 |
|
||||||
| accelerate | 0.30.1 | 0.30.1 |
|
| accelerate | 0.30.1 | 0.30.1 |
|
||||||
| peft | 0.11.1 | 0.11.1 |
|
| peft | 0.11.1 | 0.11.1 |
|
||||||
| trl | 0.8.6 | 0.9.3 |
|
| trl | 0.8.6 | 0.9.4 |
|
||||||
|
|
||||||
| Optional | Minimum | Recommend |
|
| Optional | Minimum | Recommend |
|
||||||
| ------------ | ------- | --------- |
|
| ------------ | ------- | --------- |
|
||||||
|
@ -298,7 +298,7 @@ huggingface-cli login
|
|||||||
| datasets | 2.16.0 | 2.19.2 |
|
| datasets | 2.16.0 | 2.19.2 |
|
||||||
| accelerate | 0.30.1 | 0.30.1 |
|
| accelerate | 0.30.1 | 0.30.1 |
|
||||||
| peft | 0.11.1 | 0.11.1 |
|
| peft | 0.11.1 | 0.11.1 |
|
||||||
| trl | 0.8.6 | 0.9.3 |
|
| trl | 0.8.6 | 0.9.4 |
|
||||||
|
|
||||||
| 可选项 | 至少 | 推荐 |
|
| 可选项 | 至少 | 推荐 |
|
||||||
| ------------ | ------- | --------- |
|
| ------------ | ------- | --------- |
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
import json
|
import json
|
||||||
|
from contextlib import nullcontext
|
||||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
|
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from ...extras.packages import is_requests_available
|
from ...extras.packages import is_requests_available
|
||||||
|
|
||||||
@ -28,16 +30,27 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.
|
|||||||
|
|
||||||
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
|
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
|
||||||
r"""
|
r"""
|
||||||
Replaces the default/reward modules in the model. The model is already unwrapped (and gathered).
|
Replaces the default/reward modules in the model. The model is already unwrapped.
|
||||||
"""
|
"""
|
||||||
if target == "reward": # save default head temporarily
|
if is_deepspeed_zero3_enabled():
|
||||||
setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
|
import deepspeed # type: ignore
|
||||||
setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
|
|
||||||
|
params = [model.v_head.summary.weight, model.v_head.summary.bias]
|
||||||
|
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
|
||||||
|
else:
|
||||||
|
context_maybe_zero3 = nullcontext()
|
||||||
|
|
||||||
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
|
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
|
||||||
device = model.v_head.summary.weight.device
|
with context_maybe_zero3:
|
||||||
model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
|
if target == "reward": # save default head temporarily
|
||||||
model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
|
setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
|
||||||
|
setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
|
||||||
|
|
||||||
|
device = model.v_head.summary.weight.device
|
||||||
|
model.v_head.summary.weight.data = (
|
||||||
|
model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
|
||||||
|
)
|
||||||
|
model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
|
||||||
|
|
||||||
|
|
||||||
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
|
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
|
||||||
|
@ -309,12 +309,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||||||
)
|
)
|
||||||
return lr_scheduler
|
return lr_scheduler
|
||||||
|
|
||||||
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
|
|
||||||
super()._save(output_dir, state_dict)
|
|
||||||
if self.processor is not None:
|
|
||||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
|
||||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]:
|
def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]:
|
||||||
r"""
|
r"""
|
||||||
@ -326,6 +320,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||||||
batch[k] = v[:, start_index:]
|
batch[k] = v[:, start_index:]
|
||||||
|
|
||||||
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
||||||
|
unwrapped_model = self.accelerator.unwrap_model(self.model) # issue in trl v0.8.6
|
||||||
if self.model_args.upcast_layernorm:
|
if self.model_args.upcast_layernorm:
|
||||||
layernorm_params = dump_layernorm(unwrapped_model)
|
layernorm_params = dump_layernorm(unwrapped_model)
|
||||||
|
|
||||||
@ -369,19 +364,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||||||
return get_rewards_from_server(self.reward_model, messages)
|
return get_rewards_from_server(self.reward_model, messages)
|
||||||
|
|
||||||
batch = self.prepare_model_inputs(queries, responses)
|
batch = self.prepare_model_inputs(queries, responses)
|
||||||
|
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||||
|
|
||||||
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
if self.finetuning_args.reward_model_type == "lora":
|
||||||
if self.finetuning_args.reward_model_type == "lora":
|
replace_model(unwrapped_model, target="reward")
|
||||||
replace_model(unwrapped_model, target="reward")
|
reward_model = self.model
|
||||||
reward_model = self.model
|
else:
|
||||||
else:
|
reward_model = self.reward_model
|
||||||
reward_model = self.reward_model
|
|
||||||
|
|
||||||
with self.amp_context: # support bf16
|
with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16
|
||||||
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
|
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
|
||||||
|
|
||||||
if self.finetuning_args.reward_model_type == "lora":
|
if self.finetuning_args.reward_model_type == "lora":
|
||||||
replace_model(unwrapped_model, target="default")
|
replace_model(unwrapped_model, target="default")
|
||||||
|
|
||||||
if self.is_chatglm_model: # assume same architecture
|
if self.is_chatglm_model: # assume same architecture
|
||||||
values = torch.transpose(values, 0, 1)
|
values = torch.transpose(values, 0, 1)
|
||||||
@ -482,3 +477,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||||||
self._save(output_dir, state_dict={})
|
self._save(output_dir, state_dict={})
|
||||||
remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
|
remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
|
||||||
self.model.save_checkpoint(output_dir)
|
self.model.save_checkpoint(output_dir)
|
||||||
|
|
||||||
|
if self.processor is not None:
|
||||||
|
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||||
|
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user