mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 19:52:50 +08:00
[trainer] fix batch processing in PPO trainer (#7576)
This commit is contained in:
parent
903db09822
commit
11997593be
@@ -241,9 +241,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.tokenizer.padding_side = "right"  # change padding side
         queries, responses, rewards = [], [], []
         for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
-            mini_batch_queries, mini_batch_responses = self.get_inputs(
-                batch[idx : idx + self.config.mini_batch_size]
-            )
+            mini_batch = {
+                "input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size],
+                "attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size]
+            }
+            mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch)
             mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
             queries.extend(mini_batch_queries)
             responses.extend(mini_batch_responses)
Loading…
x
Reference in New Issue
Block a user