mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-15 08:08:09 +08:00
fix PPO trainer
Former-commit-id: 21982a7d4dd9b7c3a1145b481f02b9990e32dc00
This commit is contained in:
parent
e4d0b8ee6e
commit
8bd1da7144
@@ -161,7 +161,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||||||
unwrapped_model.pretrained_model.generation_config._from_model_config = False
|
unwrapped_model.pretrained_model.generation_config._from_model_config = False
|
||||||
|
|
||||||
queries, responses = [], []
|
queries, responses = [], []
|
||||||
query, response = inputs["input_ids"], response[:, inputs["input_ids"].size(-1):].detach().cpu()
|
query, response = inputs["input_ids"].detach().cpu(), response[:, inputs["input_ids"].size(-1):].detach().cpu()
|
||||||
for i in range(len(query)):
|
for i in range(len(query)):
|
||||||
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
|
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
|
||||||
response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
|
response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
|
||||||
|
Loading…
x
Reference in New Issue
Block a user