Former-commit-id: 815b92e698562bfae6eb9a6fa1b612a05d43ed67
This commit is contained in:
hiyouga 2023-09-10 14:22:03 +08:00
parent f865d0bd51
commit 85aa16f6c6

View File

@ -50,9 +50,11 @@ class Seq2SeqPeftTrainer(PeftTrainer):
loss, generated_tokens, labels = super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
)
generated_tokens = (
generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
)
if generated_tokens is not None:
generated_tokens[:, :max(prompt_len, label_len)] = (
self.tokenizer.pad_token_id * torch.ones_like(generated_tokens[:, :max(prompt_len, label_len)])
)
generated_tokens = generated_tokens.contiguous()
return loss, generated_tokens, labels