From 11997593be227b3bbabb3fb2a13f7751c9627b08 Mon Sep 17 00:00:00 2001
From: gechengze <34020535+gechengze@users.noreply.github.com>
Date: Wed, 2 Apr 2025 21:17:48 +0800
Subject: [PATCH] [trainer] fix batch processing in PPO trainer (#7576)

---
 src/llamafactory/train/ppo/trainer.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 1684fb17..b285a003 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -241,9 +241,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.tokenizer.padding_side = "right"  # change padding side
             queries, responses, rewards = [], [], []
             for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
-                mini_batch_queries, mini_batch_responses = self.get_inputs(
-                    batch[idx : idx + self.config.mini_batch_size]
-                )
+                mini_batch = {
+                    "input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size],
+                    "attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size]
+                }
+                mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch)
                 mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
                 queries.extend(mini_batch_queries)
                 responses.extend(mini_batch_responses)
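
Note (editor's sketch, not part of the patch): the old code handed a slice of the whole batch object to get_inputs(); assuming the collated batch behaves like a plain dict of tensors, indexing it with a slice raises a TypeError, whereas the patched code slices each tensor along the batch dimension per key. A minimal, self-contained illustration (tensor shapes and mini_batch_size are made up for the example):

import torch

# Made-up shapes for illustration: a collated batch of 4 samples of length 8,
# processed in mini-batches of 2 (mirrors config.batch_size / config.mini_batch_size).
batch = {
    "input_ids": torch.randint(0, 100, (4, 8)),
    "attention_mask": torch.ones(4, 8, dtype=torch.long),
}
mini_batch_size = 2

for idx in range(0, batch["input_ids"].size(0), mini_batch_size):
    # Old code path: slicing the dict itself fails, e.g.
    #   batch[idx : idx + mini_batch_size]  ->  TypeError: unhashable type: 'slice'
    # Patched code path: slice each tensor along dim 0 instead.
    mini_batch = {
        "input_ids": batch["input_ids"][idx : idx + mini_batch_size],
        "attention_mask": batch["attention_mask"][idx : idx + mini_batch_size],
    }
    print(mini_batch["input_ids"].shape)  # torch.Size([2, 8])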