Former-commit-id: 87390ae3b70f654d520b9aadb335c9650130a42c
hiyouga 2023-11-13 22:42:23 +08:00
parent 125587b187
commit 37db26800c


@@ -5,7 +5,7 @@ import torch
 from tqdm import tqdm
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
-from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
+from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from trl import PPOTrainer
@@ -108,9 +108,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.model.eval()

             # Get inputs
-            queries, responses = self.get_inputs(batch)
             self.tokenizer.padding_side = "right"  # change padding side
-            rewards = self.get_rewards(queries, responses, unwrapped_model)
+            queries, responses, rewards = [], [], []
+            for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
+                mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
+                mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
+                queries.extend(mini_batch_queries)
+                responses.extend(mini_batch_responses)
+                rewards.extend(mini_batch_rewards)

             # Cast to training mode
             unwrapped_model.gradient_checkpointing_enable()
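The hunk above replaces the single full-batch call to get_inputs/get_rewards with a mini-batch loop, so response generation and reward scoring run mini_batch_size samples at a time and the results are accumulated into flat lists. The following is a minimal, self-contained sketch of that pattern, not the repository's code: generate_responses and score_responses are hypothetical stand-ins for the trainer's get_inputs and get_rewards, and the batch is assumed to be a plain dict of tensors.

# Minimal sketch of the mini-batch rollout pattern above (hypothetical helpers,
# not LLaMA-Factory's code).
from typing import Dict, List, Tuple

import torch


def generate_responses(mini_batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    # Placeholder "generation": echo each query and reverse it as the response.
    queries = [ids for ids in mini_batch["input_ids"]]
    responses = [ids.flip(0) for ids in mini_batch["input_ids"]]
    return queries, responses


def score_responses(queries: List[torch.Tensor], responses: List[torch.Tensor]) -> List[torch.Tensor]:
    # Placeholder reward: one fp32 scalar per sample (here, the response length).
    return [torch.tensor(float(len(r))) for r in responses]


def rollout(batch: Dict[str, torch.Tensor], batch_size: int, mini_batch_size: int):
    queries, responses, rewards = [], [], []
    for idx in range(0, batch_size, mini_batch_size):
        # Slice every tensor in the batch down to the current mini-batch window.
        mini_batch = {k: v[idx: idx + mini_batch_size] for k, v in batch.items()}
        mb_queries, mb_responses = generate_responses(mini_batch)
        mb_rewards = score_responses(mb_queries, mb_responses)
        queries.extend(mb_queries)
        responses.extend(mb_responses)
        rewards.extend(mb_rewards)
    return queries, responses, rewards


if __name__ == "__main__":
    dummy = {"input_ids": torch.randint(1, 100, (8, 16))}
    q, r, w = rollout(dummy, batch_size=8, mini_batch_size=2)
    print(len(q), len(r), len(w))  # 8 8 8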
@@ -165,7 +170,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
        )

    @torch.no_grad()
-    def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        r"""
        Generates model's responses given queries.
        """
@@ -219,7 +224,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
        rewards = []
        for i in range(values.size(0)):
-            end_indexes = (batch["input_ids"][i] != self.tokenizer.eos_token_id).nonzero()
+            end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
            end_index = end_indexes[-1].item() if len(end_indexes) else 0
            rewards.append(values[i, end_index].float().detach().cpu())  # use fp32 type
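This last hunk switches the reward lookup from eos_token_id to pad_token_id, so with right padding the value is read at the last non-padding position, presumably so a trailing EOS token is no longer skipped when the eos and pad ids differ. Below is a small self-contained sketch of that index lookup; input_ids, values, and a pad id of 0 are made-up inputs, only the nonzero()-based lookup mirrors the edited loop.

# Sketch of reading the value-head output at the last non-pad token.
import torch

pad_token_id = 0
input_ids = torch.tensor([[11, 12, 13, 2, 0],     # padded with one 0
                          [21, 22, 23, 24, 2]])   # no padding
values = torch.rand(2, 5)  # per-token value-head outputs

rewards = []
for i in range(values.size(0)):
    # Positions of all non-pad tokens in the i-th sequence.
    end_indexes = (input_ids[i] != pad_token_id).nonzero()
    # Last non-pad position; fall back to 0 if the row is all padding.
    end_index = end_indexes[-1].item() if len(end_indexes) else 0
    rewards.append(values[i, end_index].float().detach().cpu())  # keep fp32

print(rewards)  # one scalar tensor per sequence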