[trainer] fix batch processing in PPO trainer (#7576)

Author: gechengze (committed by GitHub)
Date: 2025-04-02 21:17:48 +08:00
Parent: 903db09822
Commit: 11997593be

@@ -241,9 +241,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.tokenizer.padding_side = "right"  # change padding side
             queries, responses, rewards = [], [], []
             for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
-                mini_batch_queries, mini_batch_responses = self.get_inputs(
-                    batch[idx : idx + self.config.mini_batch_size]
-                )
+                mini_batch = {
+                    "input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size],
+                    "attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size],
+                }
+                mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch)
                 mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
                 queries.extend(mini_batch_queries)
                 responses.extend(mini_batch_responses)
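
For context, a minimal sketch of the failure mode this fix addresses, assuming the collated batch is a plain dict of tensors (the shapes and the standalone script below are illustrative, not code from the repo): indexing the dict itself with a slice raises a TypeError because slice objects are unhashable, so each tensor has to be sliced key by key. The commit spells out the two keys it needs; the sketch generalizes the same idea with a dict comprehension.

    import torch

    # Illustrative stand-in for a collated PPO batch: a dict of tensors.
    # Shapes and sizes are made up for demonstration.
    batch = {
        "input_ids": torch.randint(0, 100, (8, 16)),
        "attention_mask": torch.ones(8, 16, dtype=torch.long),
    }

    mini_batch_size = 4
    for idx in range(0, batch["input_ids"].size(0), mini_batch_size):
        # Pre-fix pattern fails: a dict cannot be indexed with a slice.
        # batch[idx : idx + mini_batch_size]  -> TypeError: unhashable type: 'slice'

        # Fixed pattern: slice every tensor individually, per key.
        mini_batch = {key: value[idx : idx + mini_batch_size] for key, value in batch.items()}
        print({key: tuple(value.shape) for key, value in mini_batch.items()})

Running the sketch prints two mini-batches of shape (4, 16) per key, matching how the loop in the trainer now walks the batch in mini_batch_size strides.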