diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 5f0ceaae..a6c881ac 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -143,6 +143,9 @@ class MMPluginMixin: ) -> None: r"""Validate if this model accepts the input modalities.""" image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + video_processor: BaseImageProcessor = getattr( + processor, "video_processor", getattr(processor, "image_processor", None) + ) feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) if len(images) != 0 and self.image_token is None: raise ValueError( @@ -165,6 +168,9 @@ class MMPluginMixin: if self.image_token is not None and image_processor is None: raise ValueError("Image processor was not found, please check and update your processor config.") + if self.video_token is not None and video_processor is None: + raise ValueError("Video processor was not found, please check and update your processor config.") + if self.audio_token is not None and feature_extractor is None: raise ValueError("Audio feature extractor was not found, please check and update your processor config.") diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index 7d5aad9f..80f67c6c 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -128,9 +128,9 @@ class CustomDPOTrainer(DPOTrainer): return super()._get_train_sampler() @override - def get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs): + def get_batch_samples(self, *args, **kwargs): r"""Replace the method of DPO Trainer with the one of the standard Trainer.""" - return Trainer.get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs) + return Trainer.get_batch_samples(self, *args, **kwargs) def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor": r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.""" diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index 5f620b18..c34b4850 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -127,9 +127,9 @@ class CustomKTOTrainer(KTOTrainer): return Trainer._get_train_sampler(self) @override - def get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs): + def get_batch_samples(self, *args, **kwargs): r"""Replace the method of KTO Trainer with the one of the standard Trainer.""" - return Trainer.get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs) + return Trainer.get_batch_samples(self, *args, **kwargs) @override def forward( diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index 346611f9..faf94fc8 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -70,3 +70,7 @@ class CustomTrainer(Trainer): return torch.utils.data.SequentialSampler(self.train_dataset) return super()._get_train_sampler() + + @override + def compute_loss(self, model, inputs, *args, **kwargs): + return super().compute_loss(model, inputs, *args, **kwargs) diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index f90edeab..cd09768b 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -59,6 +59,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer") super().__init__(**kwargs) + if processor is not None: + self.model_accepts_loss_kwargs = False + self.finetuning_args = finetuning_args if gen_kwargs is not None: # https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/trainer_seq2seq.py#L287 @@ -93,6 +96,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): return super()._get_train_sampler() + @override + def compute_loss(self, model, inputs, *args, **kwargs): + return super().compute_loss(model, inputs, *args, **kwargs) + @override def prediction_step( self,