Mirror of https://github.com/hiyouga/LLaMA-Factory.git
fix eval and pred loss
Former-commit-id: 2a5a8e0eba279de603c2d25e894b6d2921aaae55
This commit is contained in:
parent 961e6a9ba4
commit fa06b168ab
@@ -79,11 +79,13 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         Subclass and override to inject custom behavior.
         """
-        input_ids = inputs["input_ids"]
+        prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
+        inputs["labels"] = torch.cat((inputs["labels"], torch.zeros_like(inputs["input_ids"])[:, label_len:]), dim=-1)
         loss, generated_tokens, labels = super().prediction_step(
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
-        generated_tokens = generated_tokens[:, input_ids.size(-1):] if generated_tokens is not None else None
+        generated_tokens = generated_tokens[:, prompt_len:] if generated_tokens is not None else None
 
         return (loss, generated_tokens, labels)
 
     def save_predictions(
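
The patch pads inputs["labels"] with zeros up to the prompt length before delegating to the parent prediction_step, then slices the prompt prefix off the generated tokens. Below is a minimal, standalone sketch of those two tensor operations; the batch size, sequence lengths, and token values are invented for illustration and are not taken from the repository.

import torch

# Standalone sketch of the two tensor operations in the patch (shapes and
# values are illustrative only). With generate-based evaluation, the labels
# tensor can be shorter than input_ids, so it is right-padded with zeros until
# both tensors share the same sequence length before the parent
# prediction_step computes the eval loss.
input_ids = torch.ones(2, 8, dtype=torch.long)    # (batch=2, prompt_len=8)
labels = torch.full((2, 5), 3, dtype=torch.long)  # (batch=2, label_len=5)

prompt_len, label_len = input_ids.size(-1), labels.size(-1)
labels = torch.cat((labels, torch.zeros_like(input_ids)[:, label_len:]), dim=-1)
print(labels.shape)  # torch.Size([2, 8]) -- now aligned with input_ids

# generate() on a decoder-only model returns the prompt followed by the newly
# generated tokens, so the prompt prefix is sliced off before decoding.
generated_tokens = torch.ones(2, 12, dtype=torch.long)  # prompt + 4 new tokens
generated_tokens = generated_tokens[:, prompt_len:] if generated_tokens is not None else None
print(generated_tokens.shape)  # torch.Size([2, 4])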