diff --git a/.gitignore b/.gitignore
index 40bc040a7..f497809b7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -85,7 +85,7 @@ ipython_config.py
 # pyenv
 # For a library or package, you might want to ignore these files since the code is
 # intended to run in multiple environments; otherwise, check them in:
-# .python-version
+.python-version
 
 # pipenv
 # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index f860c854f..4a277fd5a 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -1624,7 +1624,12 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
                 for video, duration in zip(videos["videos"], videos["durations"])
             ]
             mm_inputs.update(
-                video_processor(videos=videos["videos"], video_metadata=video_metadata, fps=getattr(processor, "video_fps", 2.0), return_metadata=True)
+                video_processor(
+                    videos=videos["videos"],
+                    video_metadata=video_metadata,
+                    fps=getattr(processor, "video_fps", 2.0),
+                    return_metadata=True,
+                )
             )
             temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
             if "second_per_grid_ts" in processor.model_input_names:
diff --git a/src/llamafactory/train/dpo/ktrainer.py b/src/llamafactory/train/dpo/ktrainer.py
index d638d8890..0da2c6851 100644
--- a/src/llamafactory/train/dpo/ktrainer.py
+++ b/src/llamafactory/train/dpo/ktrainer.py
@@ -15,32 +15,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
+
 import torch
-import torch.nn.functional as F
-from transformers import Trainer
-from trl import DPOTrainer
-from trl.trainer import disable_dropout_in_model
+from ktransformers.sft.lora import KTrainer  # type: ignore
 from typing_extensions import override
 
-from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_greater_than
-from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach
-from .trainer import CustomDPOTrainer as BaseDPOTrainer
-from ktransformers.sft.lora import KTrainer
+from ..trainer_utils import get_batch_logps, nested_detach
+from .trainer import CustomDPOTrainer
 
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, ProcessorMixin
-
-    from ...hparams import FinetuningArguments
+    from transformers import PreTrainedModel
 
 
-class CustomDPOTrainer(KTrainer, BaseDPOTrainer):
+class KDPOTrainer(KTrainer, CustomDPOTrainer):
     @override
     def concatenated_forward(
-        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
     ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
 
@@ -48,9 +40,8 @@ class CustomDPOTrainer(KTrainer, BaseDPOTrainer):
         """
         if self.finetuning_args.use_ref_model:
             batch = nested_detach(batch, clone=True)  # avoid error
-        labels = batch["labels"]
-        # dpo not need compute loss in forward, waste mem
-        del batch["labels"]
+
+        labels = batch.pop("labels")  # dpo does not need to compute loss in forward
         all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
         all_logits = all_logits.to("cpu")
         labels = labels.to(all_logits.device)
@@ -68,4 +59,4 @@ class CustomDPOTrainer(KTrainer, BaseDPOTrainer):
         if self.loss_type in ["ipo", "orpo", "simpo"]:
             return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
         else:
-            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
\ No newline at end of file
+            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index c0ebc301e..acc1c4863 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -218,9 +218,10 @@ class CustomDPOTrainer(DPOTrainer):
         if self.finetuning_args.use_ref_model:
             batch = nested_detach(batch, clone=True)  # avoid error
 
+        labels = batch.pop("labels")  # dpo does not need to compute loss in forward
         all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
         all_logps, valid_length = get_batch_logps(
-            logits=all_logits, labels=batch["labels"], ld_alpha=(self.ld_alpha if not is_ref_model else None)
+            logits=all_logits, labels=labels, ld_alpha=(self.ld_alpha if not is_ref_model else None)
         )
         if self.loss_type in ["ipo", "orpo", "simpo"]:
             all_logps = all_logps / valid_length
diff --git a/src/llamafactory/train/dpo/workflow.py b/src/llamafactory/train/dpo/workflow.py
index 4e3c8f8f3..83ad38dfa 100644
--- a/src/llamafactory/train/dpo/workflow.py
+++ b/src/llamafactory/train/dpo/workflow.py
@@ -61,14 +61,14 @@ def run_dpo(
         ref_model = create_ref_model(model_args, finetuning_args)
     else:
         ref_model = None
-
-
+
     if model_args.use_kt:
-        from ktransformers.util.globals import GLOBAL_CONFIG
+        from ktransformers.util.globals import GLOBAL_CONFIG  # type: ignore
+
+        from .ktrainer import KDPOTrainer as CustomDPOTrainer
 
         GLOBAL_CONFIG._config["mod"] = "sft"
-
-        from .ktrainer import CustomDPOTrainer
+
     else:
         from .trainer import CustomDPOTrainer
 
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index b289f963d..1bf14f1eb 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -73,7 +73,7 @@ def run_sft(
             raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet.")
         elif finetuning_args.compute_accuracy:
             raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet.")
-    
+
     if training_args.predict_with_generate:
         metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
     elif finetuning_args.compute_accuracy:
@@ -99,8 +99,8 @@ def run_sft(
 
     # Initialize our Trainer
     if model_args.use_kt:
-        from ktransformers.util.globals import GLOBAL_CONFIG
-        from ktransformers.sft.lora import KTrainer
+        from ktransformers.sft.lora import KTrainer  # type: ignore
+        from ktransformers.util.globals import GLOBAL_CONFIG  # type: ignore
 
         GLOBAL_CONFIG._config["mod"] = "sft"