Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 22:32:54 +08:00)
commit 3bcb4633ca
parent 377dfe5665
@@ -121,5 +121,6 @@ def main():
     else:
         raise NotImplementedError(f"Unknown command: {command}.")


 if __name__ == "__main__":
     main()
@@ -105,7 +105,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
             fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
             fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
-            fake_input_ids = self.processor.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
+            fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
+            fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
+                fake_input_ids, None, fake_images, [], self.tokenizer, self.processor
+            )
             if self.tokenizer.padding_side == "right":
                 features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
                 features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
@@ -116,6 +119,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]

             batch_images = fake_images
             batch_imglens[0] = 1
+            batch_input_ids[0] = features[0]["input_ids"]

         mm_inputs = self.template.mm_plugin.get_mm_inputs(
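Note: this collator change handles batches that contain no images at all. A placeholder image message is tokenized, run through the multimodal plugin, and its token ids are attached to the first feature with zeroed attention and IGNORE_INDEX labels so they cannot affect the loss; the new batch_input_ids[0] assignment keeps the per-sample id list passed to get_mm_inputs in sync with that mutated feature. A minimal, self-contained sketch of the attach logic under assumed names (attach_fake_tokens, pad_side), not the project's actual collator:

IGNORE_INDEX = -100  # label id ignored by the loss (assumed; matches the constant used above)

def attach_fake_tokens(feature: dict, fake_input_ids: list, pad_side: str) -> dict:
    """Append (right padding) or prepend (left padding) masked fake tokens to one feature."""
    pad_mask = [0] * len(fake_input_ids)                # never attended to
    pad_labels = [IGNORE_INDEX] * len(fake_input_ids)   # never contributes to the loss
    if pad_side == "right":
        feature["input_ids"] = feature["input_ids"] + fake_input_ids
        feature["attention_mask"] = feature["attention_mask"] + pad_mask
        feature["labels"] = feature["labels"] + pad_labels
    else:
        feature["input_ids"] = fake_input_ids + feature["input_ids"]
        feature["attention_mask"] = pad_mask + feature["attention_mask"]
        feature["labels"] = pad_labels + feature["labels"]
    return feature

# toy usage
print(attach_fake_tokens({"input_ids": [1, 2], "attention_mask": [1, 1], "labels": [-100, 2]}, [9, 9], "right"))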
@@ -79,6 +79,7 @@ class CustomDPOTrainer(DPOTrainer):
         self.simpo_gamma = finetuning_args.simpo_gamma

         Trainer.__init__(self, model=model, **kwargs)
+        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
         if not hasattr(self, "accelerator"):
             raise AttributeError("Please update `transformers`.")

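Note: setting model_accepts_loss_kwargs = False keeps the 4.46+ Trainer from treating the wrapped model as one that normalizes its own loss via num_items_in_batch, so the custom compute_loss below stays in control of scaling. In transformers that flag is derived roughly by checking whether the model's forward() takes a **kwargs catch-all; a hedged sketch of that idea only (not the library's exact code):

import inspect

def forward_accepts_loss_kwargs(forward_fn) -> bool:
    # Rough idea: a forward() with **kwargs can receive extra loss-related
    # keywords such as num_items_in_batch.
    params = inspect.signature(forward_fn).parameters.values()
    return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params)

def forward_a(input_ids, labels=None): ...
def forward_b(input_ids, labels=None, **loss_kwargs): ...

print(forward_accepts_loss_kwargs(forward_a))  # False
print(forward_accepts_loss_kwargs(forward_b))  # True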
@@ -274,15 +275,14 @@ class CustomDPOTrainer(DPOTrainer):
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
     ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value for transformers 4.46.0.
-        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
         """
         loss = super().compute_loss(model, inputs, return_outputs)
-        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
+        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
             if return_outputs:
-                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
             else:
-                return loss / self.args.gradient_accumulation_steps
+                loss = loss / self.args.gradient_accumulation_steps

         return loss

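Note: the rescaling exists because the affected transformers 4.46.x releases skipped the division by gradient_accumulation_steps for loss paths that do not consume num_items_in_batch, which inflated the computed loss by the accumulation factor (see the PR linked in the new docstring). Switching from early return to assignment also lets every branch fall through to the single return loss. A toy, hypothetical illustration of the effect, not the trainer itself:

grad_accum_steps = 4
micro_batch_losses = [2.0, 2.2, 1.8, 2.0]

# Without the workaround the value behaves like an unscaled sum over the
# accumulated micro-batches; dividing by the accumulation steps restores a
# per-step average comparable to earlier transformers versions.
unscaled = sum(micro_batch_losses)       # 8.0 -> looks ~4x too large
rescaled = unscaled / grad_accum_steps   # 2.0 -> comparable to a single step
print(unscaled, rescaled)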
@@ -77,6 +77,7 @@ class CustomKTOTrainer(KTOTrainer):
         self.ftx_gamma = finetuning_args.pref_ftx

         Trainer.__init__(self, model=model, **kwargs)
+        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
         if not hasattr(self, "accelerator"):
             raise AttributeError("Please update `transformers`.")

@@ -252,15 +253,14 @@ class CustomKTOTrainer(KTOTrainer):
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
     ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value for transformers 4.46.0.
-        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
         """
         loss = super().compute_loss(model, inputs, return_outputs)
-        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
+        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
             if return_outputs:
-                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
             else:
-                return loss / self.args.gradient_accumulation_steps
+                loss = loss / self.args.gradient_accumulation_steps

         return loss

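Note: the KTO trainer gets the same workaround as the DPO trainer. One subtle part of the change in both is replacing kwargs.pop(...) with kwargs.get(...): pop removes the key, so anything that still needs num_items_in_batch afterwards would no longer see it, while get only reads it. A two-line illustration:

kwargs = {"num_items_in_batch": 4096}
assert kwargs.get("num_items_in_batch") == 4096 and "num_items_in_batch" in kwargs      # get() leaves the dict intact
assert kwargs.pop("num_items_in_batch") == 4096 and "num_items_in_batch" not in kwargs  # pop() consumes the key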
@@ -19,7 +19,7 @@ import torch
 from transformers import Trainer
 from typing_extensions import override

-from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
+from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

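Note: the import of is_transformers_version_equal_to_4_46 is dropped here because the compute_loss gate below keys off num_items_in_batch directly instead of the library version. For reference, such a helper is typically a thin wrapper over the installed package metadata; a hypothetical sketch only (the real helper lives in extras.packages and its version bounds may differ):

from importlib.metadata import version

from packaging.version import Version

def is_transformers_version_equal_to_4_46() -> bool:
    # Assumed behavior: treat the 4.46.x releases affected by the loss-scaling
    # regression as "equal to 4.46"; the project's own helper may use other bounds.
    v = Version(version("transformers"))
    return Version("4.46.0") <= v < Version("4.47.0")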
@@ -78,15 +78,13 @@ class CustomTrainer(Trainer):
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
     ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value for transformers 4.46.0.
-        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
         """
         loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
-        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
-            # other model should not scale the loss
+        if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
             if return_outputs:
-                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
             else:
-                return loss / self.args.gradient_accumulation_steps
+                loss = loss / self.args.gradient_accumulation_steps

         return loss
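Note: after this change the pre-training trainer rescales only when both conditions hold: the Trainer actually passed num_items_in_batch, and the wrapped model could not consume it (otherwise the loss is already normalized). A tiny truth-table check of that gate, written as a standalone helper (needs_manual_rescale is a hypothetical name for illustration):

def needs_manual_rescale(num_items_in_batch, model_accepts_loss_kwargs: bool) -> bool:
    # Mirrors the `if` gate above: rescale only when the token count was passed
    # but the model's loss never used it.
    return bool(num_items_in_batch) and not model_accepts_loss_kwargs

assert needs_manual_rescale(1024, False) is True   # custom loss path -> divide by GA steps
assert needs_manual_rescale(1024, True) is False   # model already normalized the loss
assert needs_manual_rescale(None, False) is False  # kwarg absent (older transformers) -> nothing to fix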
@@ -52,6 +52,7 @@ class PairwiseTrainer(Trainer):
             kwargs["processing_class"] = kwargs.pop("tokenizer")

         super().__init__(**kwargs)
+        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
         self.finetuning_args = finetuning_args
         self.can_return_loss = True  # override property to return eval_loss
         self.add_callback(FixValueHeadModelCallback)
@@ -107,8 +108,8 @@ class PairwiseTrainer(Trainer):

         loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()

-        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
-            loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0
+        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
+            loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0-4.46.1

         if return_outputs:
             return loss, (loss, chosen_scores, rejected_scores)
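Note: the loss in this hunk is the standard Bradley-Terry pairwise objective, -log sigmoid(score_chosen - score_rejected) averaged over the batch; the 4.46 workaround only rescales its value. A small standalone check of the formula with toy scores (not project code):

import torch
import torch.nn.functional as F

chosen_scores = torch.tensor([1.2, 0.3])
rejected_scores = torch.tensor([0.4, 0.9])

# -log(sigmoid(chosen - rejected)): small when chosen outscores rejected by a wide margin.
loss = -F.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
print(loss.item())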
@@ -27,7 +27,7 @@ from typing_extensions import override

 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
+from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

@@ -93,16 +93,14 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
     ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value for transformers 4.46.0.
-        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
         """
         loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
-        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
-            # other model should not scale the loss
+        if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
             if return_outputs:
-                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
             else:
-                return loss / self.args.gradient_accumulation_steps
+                loss = loss / self.args.gradient_accumulation_steps

         return loss
