Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-06 05:32:50 +08:00)
Commit 85aa16f6c6 (parent f865d0bd51)
@@ -50,9 +50,11 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         loss, generated_tokens, labels = super().prediction_step(
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
-        generated_tokens = (
-            generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
-        )
+        if generated_tokens is not None:
+            generated_tokens[:, :max(prompt_len, label_len)] = (
+                self.tokenizer.pad_token_id * torch.ones_like(generated_tokens[:, :max(prompt_len, label_len)])
+            )
+            generated_tokens = generated_tokens.contiguous()
         return loss, generated_tokens, labels
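The commit swaps slicing for in-place masking: instead of cutting the prompt prefix off generated_tokens (which shrank the tensor), the first max(prompt_len, label_len) positions are overwritten with the pad token, so the output keeps its full sequence length. Below is a minimal standalone sketch of the new logic; the batch contents, the pad_token_id value, and the prompt_len/label_len values are dummy stand-ins for the trainer's real attributes, not values taken from the commit.

import torch

pad_token_id = 0              # stand-in for self.tokenizer.pad_token_id
prompt_len, label_len = 4, 3  # dummy lengths, for illustration only

# Fake batch of generated ids: prompt tokens followed by newly generated tokens.
generated_tokens = torch.arange(1, 13).reshape(2, 6)

if generated_tokens is not None:
    # Overwrite the prompt/label prefix with pad tokens in place instead of
    # slicing it off, so the sequence length of the batch is unchanged.
    generated_tokens[:, :max(prompt_len, label_len)] = (
        pad_token_id * torch.ones_like(generated_tokens[:, :max(prompt_len, label_len)])
    )
    # Ensure a contiguous memory layout for downstream ops, as the patch does.
    generated_tokens = generated_tokens.contiguous()

print(generated_tokens)
# tensor([[ 0,  0,  0,  0,  5,  6],
#         [ 0,  0,  0,  0, 11, 12]])

Masking rather than slicing keeps every example in the batch at the same shape regardless of prompt length, which keeps the predictions dimensionally aligned with the labels returned alongside them; the padded prefix can then be ignored when the tokens are decoded.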