diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index 94cc9fce..ef5bd1dc 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -121,5 +121,6 @@ def main(): else: raise NotImplementedError(f"Unknown command: {command}.") + if __name__ == "__main__": main() diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 3b044443..e1c0f247 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -105,7 +105,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}] fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))] fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor) - fake_input_ids = self.processor.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False) + fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False) + fake_input_ids, _ = self.template.mm_plugin.process_token_ids( + fake_input_ids, None, fake_images, [], self.tokenizer, self.processor + ) if self.tokenizer.padding_side == "right": features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids) @@ -116,6 +119,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"] batch_images = fake_images + batch_imglens[0] = 1 batch_input_ids[0] = features[0]["input_ids"] mm_inputs = self.template.mm_plugin.get_mm_inputs( diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index 9d1e0104..59862443 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -79,6 +79,7 @@ class CustomDPOTrainer(DPOTrainer): self.simpo_gamma = finetuning_args.simpo_gamma Trainer.__init__(self, model=model, **kwargs) + self.model_accepts_loss_kwargs = False # overwrite trainer's default behavior if not hasattr(self, "accelerator"): raise AttributeError("Please update `transformers`.") @@ -274,15 +275,14 @@ class CustomDPOTrainer(DPOTrainer): self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: r""" - Fixes the loss value for transformers 4.46.0. - https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605 + Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details. """ loss = super().compute_loss(model, inputs, return_outputs) - if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False): + if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"): if return_outputs: - return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:]) + loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:]) else: - return loss / self.args.gradient_accumulation_steps + loss = loss / self.args.gradient_accumulation_steps return loss diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index 3d007ae7..f60c2366 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -77,6 +77,7 @@ class CustomKTOTrainer(KTOTrainer): self.ftx_gamma = finetuning_args.pref_ftx Trainer.__init__(self, model=model, **kwargs) + self.model_accepts_loss_kwargs = False # overwrite trainer's default behavior if not hasattr(self, "accelerator"): raise AttributeError("Please update `transformers`.") @@ -252,15 +253,14 @@ class CustomKTOTrainer(KTOTrainer): self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: r""" - Fixes the loss value for transformers 4.46.0. - https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605 + Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details. """ loss = super().compute_loss(model, inputs, return_outputs) - if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False): + if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"): if return_outputs: - return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:]) + loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:]) else: - return loss / self.args.gradient_accumulation_steps + loss = loss / self.args.gradient_accumulation_steps return loss diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index 445462b9..11d91111 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -19,7 +19,7 @@ import torch from transformers import Trainer from typing_extensions import override -from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than +from ...extras.packages import is_transformers_version_greater_than from ..callbacks import SaveProcessorCallback from ..trainer_utils import create_custom_optimizer, create_custom_scheduler @@ -78,15 +78,13 @@ class CustomTrainer(Trainer): self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: r""" - Fixes the loss value for transformers 4.46.0. - https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605 + Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details. """ loss = super().compute_loss(model, inputs, return_outputs, **kwargs) - if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False): - # other model should not scale the loss + if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False): if return_outputs: - return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:]) + loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:]) else: - return loss / self.args.gradient_accumulation_steps + loss = loss / self.args.gradient_accumulation_steps return loss diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py index 347cae9b..574b87b2 100644 --- a/src/llamafactory/train/rm/trainer.py +++ b/src/llamafactory/train/rm/trainer.py @@ -52,6 +52,7 @@ class PairwiseTrainer(Trainer): kwargs["processing_class"] = kwargs.pop("tokenizer") super().__init__(**kwargs) + self.model_accepts_loss_kwargs = False # overwrite trainer's default behavior self.finetuning_args = finetuning_args self.can_return_loss = True # override property to return eval_loss self.add_callback(FixValueHeadModelCallback) @@ -107,8 +108,8 @@ class PairwiseTrainer(Trainer): loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean() - if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False): - loss /= self.args.gradient_accumulation_steps # fixes the loss value for transformers 4.46.0 + if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"): + loss /= self.args.gradient_accumulation_steps # fixes the loss value for transformers 4.46.0-4.46.1 if return_outputs: return loss, (loss, chosen_scores, rejected_scores) diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 6ba758cd..45998262 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -27,7 +27,7 @@ from typing_extensions import override from ...extras import logging from ...extras.constants import IGNORE_INDEX -from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than +from ...extras.packages import is_transformers_version_greater_than from ..callbacks import SaveProcessorCallback from ..trainer_utils import create_custom_optimizer, create_custom_scheduler @@ -93,16 +93,14 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: r""" - Fixes the loss value for transformers 4.46.0. - https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605 + Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details. """ loss = super().compute_loss(model, inputs, return_outputs, **kwargs) - if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False): - # other model should not scale the loss + if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False): if return_outputs: - return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:]) + loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:]) else: - return loss / self.args.gradient_accumulation_steps + loss = loss / self.args.gradient_accumulation_steps return loss