Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)
[trainer] fix batch processing in PPO trainer (#7576)
parent 903db09822
commit 11997593be
```diff
@@ -241,9 +241,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.tokenizer.padding_side = "right"  # change padding side
         queries, responses, rewards = [], [], []
         for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
-            mini_batch_queries, mini_batch_responses = self.get_inputs(
-                batch[idx : idx + self.config.mini_batch_size]
-            )
+            mini_batch = {
+                "input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size],
+                "attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size]
+            }
+            mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch)
             mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
             queries.extend(mini_batch_queries)
             responses.extend(mini_batch_responses)
```
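What the fix does: the old code passed a positional slice of the collated batch straight to `self.get_inputs`, which breaks when the batch is a mapping from field names to tensors rather than a sliceable sequence (slicing a dict raises `TypeError: unhashable type: 'slice'`). The new code slices each tensor field individually and rebuilds a small dict per mini-batch. Below is a minimal, self-contained sketch of that slicing pattern; the helper name `make_mini_batches` and the toy tensors are illustrative, not from the repository.

```python
import torch


def make_mini_batches(batch: dict, batch_size: int, mini_batch_size: int):
    """Yield dict-shaped mini-batches by slicing each tensor field,
    mirroring the loop in the diff above (only the slicing pattern
    matches the commit; this helper itself is hypothetical)."""
    for idx in range(0, batch_size, mini_batch_size):
        yield {
            "input_ids": batch["input_ids"][idx : idx + mini_batch_size],
            "attention_mask": batch["attention_mask"][idx : idx + mini_batch_size],
        }


# Toy example: a "batch" of 4 sequences of length 8, split into mini-batches of 2.
batch = {
    "input_ids": torch.arange(4 * 8).reshape(4, 8),
    "attention_mask": torch.ones(4, 8, dtype=torch.long),
}
for mini_batch in make_mini_batches(batch, batch_size=4, mini_batch_size=2):
    print(mini_batch["input_ids"].shape)  # torch.Size([2, 8])
```

Keeping each mini-batch as a dict preserves the interface `get_inputs` expects, and `queries.extend(...)` / `responses.extend(...)` then accumulate the per-sequence results across mini-batches exactly as before.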