diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py index 1684fb17..b285a003 100644 --- a/src/llamafactory/train/ppo/trainer.py +++ b/src/llamafactory/train/ppo/trainer.py @@ -241,9 +241,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer): self.tokenizer.padding_side = "right" # change padding side queries, responses, rewards = [], [], [] for idx in range(0, self.config.batch_size, self.config.mini_batch_size): - mini_batch_queries, mini_batch_responses = self.get_inputs( - batch[idx : idx + self.config.mini_batch_size] - ) + mini_batch = { + "input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size], + "attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size] + } + mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch) mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses) queries.extend(mini_batch_queries) responses.extend(mini_batch_responses)