Mirror of https://github.com/hiyouga/LLaMA-Factory.git
update ppo trainer

Former-commit-id: b5ba87952ab02ed0720365ebd571e47e92e1cda6

commit 569df8ccd6
parent ab739e72ea
@@ -47,7 +47,6 @@ class PeftTrainer(Seq2SeqTrainer):
         logger.info(f"Saving model checkpoint to {output_dir}")

         model = unwrap_model(self.model)

         if isinstance(model, PreTrainedModelWrapper):
             # Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
             model_state_dict = state_dict or model.state_dict()
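The custom state dict referenced above exists because trl's value-head wrapper reports the backbone weights together with extra "v_head.*" entries, so saving code often needs to separate the two. A minimal sketch of that separation, assuming only the "v_head." key prefix used by trl v0.4.7 (the helper name and toy tensors are illustrative, not part of this repository):

    import torch

    def split_value_head_state_dict(model_state_dict):
        # Separate value-head weights (keys prefixed with "v_head.") from backbone weights.
        backbone, v_head = {}, {}
        for name, tensor in model_state_dict.items():
            (v_head if name.startswith("v_head.") else backbone)[name] = tensor
        return backbone, v_head

    # Toy combined state dict in the shape the wrapper would return
    combined = {
        "transformer.wte.weight": torch.zeros(4, 2),
        "v_head.summary.weight": torch.zeros(1, 2),
        "v_head.summary.bias": torch.zeros(1),
    }
    backbone_sd, v_head_sd = split_value_head_state_dict(combined)
    assert set(v_head_sd) == {"v_head.summary.weight", "v_head.summary.bias"}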
@@ -2,10 +2,9 @@ import os
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple

 from transformers import TrainerState, TrainerControl
-from transformers.modeling_utils import PreTrainedModel

 from trl import PPOTrainer
 from trl.core import LengthSampler
@@ -18,6 +17,7 @@ from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model

 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments
+    from trl import AutoModelForCausalLMWithValueHead
     from llmtuner.extras.callbacks import LogCallback
     from llmtuner.hparams import FinetuningArguments

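Because `AutoModelForCausalLMWithValueHead` is imported only under `TYPE_CHECKING`, the annotations added later in this commit have to be written as strings (e.g. `unwrapped_model: "AutoModelForCausalLMWithValueHead"`); the import is never executed at runtime. A minimal sketch of the pattern with a generic function (the function itself is illustrative, not this module's code):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Visible to static type checkers only; skipped at runtime, so it adds
        # no import cost and cannot create an import cycle.
        from trl import AutoModelForCausalLMWithValueHead

    def describe(model: "AutoModelForCausalLMWithValueHead") -> str:
        # The quoted annotation is resolved lazily, so this function runs even
        # though the class was never imported at runtime.
        return type(model).__name__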
@@ -43,7 +43,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         self.log_callback = callbacks[0]
         self.state = TrainerState()
         self.control = TrainerControl()
-        self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
         self._remove_log()

     def ppo_train(self, max_target_length: int) -> None:
@@ -83,7 +82,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             "logits_processor": get_logits_processor()
         }
         length_sampler = LengthSampler(max_target_length // 2, max_target_length)
-        unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)
+        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)

         dataiter = iter(self.dataloader)
         steps_trained = 0
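`accelerator.prepare` may hand back a wrapped module (for example `DistributedDataParallel`), so attributes such as `config`, `pretrained_model`, and `generate` are reached through `accelerator.unwrap_model`, which is why the annotation above describes the unwrapped value-head model. A small sketch of the prepare/unwrap round trip with a toy module (which wrapper, if any, you get depends on the launch configuration, so treat the printed types as environment-specific):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = torch.nn.Linear(4, 2)

    prepared = accelerator.prepare(model)           # possibly wrapped and moved to the right device
    unwrapped = accelerator.unwrap_model(prepared)  # the underlying nn.Module, usable as usual

    print(type(prepared).__name__, type(unwrapped).__name__)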
@@ -95,38 +94,22 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             batch = next(dataiter)
             steps_trained += 1

+            # Cast to inference mode
             unwrapped_model.gradient_checkpointing_disable()
             unwrapped_model.config.use_cache = True
+            unwrapped_model.eval()

-            # Get responses
-            query_tensors = batch["input_ids"]
-            response_tensors = self.generate(
-                batch, length_sampler, return_prompt=False, **gen_kwargs
-            ).detach().cpu() # move to cpu
+            # Get inputs
+            queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
+            rewards = self.get_rewards(queries, responses, unwrapped_model)

-            queries, responses = [], []
-            for i in range(len(query_tensors)):
-                query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
-                response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
-                queries.append(query_tensors[i, query_length:]) # remove padding from left
-                responses.append(response_tensors[i, :response_length]) # remove padding from right
-
-            # Compute rewards
-            replace_model(unwrapped_model, target="reward")
-            with torch.no_grad():
-                _, _, values = self.model(
-                    **self.prepare_model_inputs(queries, responses),
-                    output_hidden_states=True,
-                    return_dict=True
-                )
-            rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
-            replace_model(unwrapped_model, target="default")
-
-            # Run PPO step
+            # Cast to training mode
             unwrapped_model.gradient_checkpointing_enable()
             unwrapped_model.config.use_cache = False
-            stats = self.step(queries, responses, rewards)
+            unwrapped_model.train()

+            # Run PPO step
+            stats = self.step(queries, responses, rewards)
             loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
             reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))

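The loop body now toggles the policy between a generation-friendly configuration (gradient checkpointing off, KV cache on, `eval()`) for rollout and reward scoring, and the training configuration before `self.step`. A sketch of that toggle as a reusable context manager, assuming any HF-style causal LM that exposes `gradient_checkpointing_enable/disable` and `config.use_cache` (illustration only, not part of the commit):

    from contextlib import contextmanager

    @contextmanager
    def generation_mode(model):
        # Switch into a configuration suited for model.generate(): no gradient
        # checkpointing, KV cache enabled, eval mode.
        model.gradient_checkpointing_disable()
        model.config.use_cache = True
        model.eval()
        try:
            yield model
        finally:
            # Restore the training configuration expected by the PPO update.
            model.gradient_checkpointing_enable()
            model.config.use_cache = False
            model.train()

    # Hypothetical usage mirroring the loop above:
    # with generation_mode(unwrapped_model):
    #     queries, responses = trainer.get_inputs(batch, length_sampler, **gen_kwargs)
    #     rewards = trainer.get_rewards(queries, responses, unwrapped_model)
    # stats = trainer.step(queries, responses, rewards)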
@@ -155,37 +138,57 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
                 steps_trained = 0

     @torch.no_grad()
-    def generate(
+    def get_inputs(
         self,
         inputs: Dict[str, torch.Tensor],
         length_sampler: Optional[Callable] = None,
-        return_prompt: Optional[bool] = True,
         **generation_kwargs
-    ) -> torch.Tensor:
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
-
-        Subclass and override to inject custom behavior.
         """
-        self.model, layer_norm_params = cast_layernorm_dtype(self.model)
-
         if length_sampler is not None:
             generation_kwargs["max_new_tokens"] = length_sampler()

-        unwrapped_model = self.accelerator.unwrap_model(self.model)
+        self.model, layer_norm_params = cast_layernorm_dtype(self.model)
+        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)

-        response = unwrapped_model.generate(**inputs, **generation_kwargs)
+        response: torch.Tensor = unwrapped_model.generate(**inputs, **generation_kwargs)
+        self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)

         # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
         # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
         if unwrapped_model.pretrained_model.generation_config._from_model_config:
             unwrapped_model.pretrained_model.generation_config._from_model_config = False

-        self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
+        queries, responses = [], []
+        query, response = inputs["input_ids"], response[:, inputs["input_ids"].size(-1):].detach().cpu()
+        for i in range(len(query)):
+            query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
+            response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
+            queries.append(query[i, query_length:]) # remove padding from left
+            responses.append(response[i, :response_length]) # remove padding from right

-        if not return_prompt and not self.is_encoder_decoder:
-            return response[:, inputs["input_ids"].size(1):]
+        return queries, responses

-        return response
+    @torch.no_grad()
+    def get_rewards(
+        self,
+        queries: List[torch.Tensor],
+        responses: List[torch.Tensor],
+        unwrapped_model: "AutoModelForCausalLMWithValueHead"
+    ) -> List[torch.Tensor]:
+        r"""
+        Computes scores using given reward model.
+        """
+        replace_model(unwrapped_model, target="reward")
+        _, _, values = self.model(
+            **self.prepare_model_inputs(queries, responses),
+            output_hidden_states=True,
+            return_dict=True
+        )
+        rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
+        replace_model(unwrapped_model, target="default")
+        return rewards

     def save_model(self, output_dir: Optional[str] = None) -> None:
         r"""
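`get_inputs` trims the left padding off each query and the right padding off each response by locating the first and last non-pad positions with `nonzero()`; `get_rewards` then swaps in the reward-model weights via `replace_model`, takes the value head's output at the last position (`values[:, -1]`) as the scalar reward, and swaps the default weights back. A self-contained illustration of the padding arithmetic with dummy tensors and a hypothetical pad id (not the trainer code itself):

    import torch

    pad_id = 0  # hypothetical pad token id

    # Left-padded query (from the data collator) and right-padded response (from generate)
    query = torch.tensor([pad_id, pad_id, 11, 12, 13])
    response = torch.tensor([21, 22, pad_id, pad_id])

    query_start = (query != pad_id).nonzero()[0]            # first real query token -> index 2
    response_end = (response != pad_id).nonzero()[-1] + 1   # one past the last real response token -> 2

    print(query[query_start:])      # tensor([11, 12, 13]); padding removed from the left
    print(response[:response_end])  # tensor([21, 22]); padding removed from the right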