fix PPO trainer

Former-commit-id: 1d8a1878ea053d1dbfc570eea868d2514ce75a51
This commit is contained in:
hiyouga 2023-08-02 19:10:23 +08:00
parent 569df8ccd6
commit 4b8e4398bc

View File

@@ -161,7 +161,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
     unwrapped_model.pretrained_model.generation_config._from_model_config = False
     queries, responses = [], []
-    query, response = inputs["input_ids"], response[:, inputs["input_ids"].size(-1):].detach().cpu()
+    query, response = inputs["input_ids"].detach().cpu(), response[:, inputs["input_ids"].size(-1):].detach().cpu()
     for i in range(len(query)):
         query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
         response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1