fix sft trainer

Former-commit-id: df946e6949c77179a5080b780109e22c297caef8
This commit is contained in:
hiyouga 2023-08-09 16:35:03 +08:00
parent 28a807472b
commit b43f37ca19

View File

@ -79,7 +79,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
padded_tensor = pad_token_id * torch.ones_like(tgt_tensor) padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
return padded_tensor return padded_tensor.contiguous()
def save_predictions( def save_predictions(
self, self,