diff --git a/setup.py b/setup.py
index bf7662c8..6a7c2791 100644
--- a/setup.py
+++ b/setup.py
@@ -54,7 +54,7 @@ extra_require = {
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<0.6.5"],
+    "vllm": ["vllm>=0.4.3,<0.6.7"],
     "galore": ["galore-torch"],
     "badam": ["badam>=1.2.1"],
     "adam-mini": ["adam-mini"],
diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index d001386b..1004544d 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -171,7 +171,7 @@ class HuggingfaceEngine(BaseEngine):
             elif not isinstance(value, torch.Tensor):
                 value = torch.tensor(value)
 
-            if torch.is_floating_point(value):
+            if torch.is_floating_point(value):  # cast data dtype for paligemma
                 value = value.to(model.dtype)
 
             gen_kwargs[key] = value.to(model.device)
diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index e1c0f247..f360f0f5 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -168,6 +168,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
         if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
             features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
 
+        for key, value in features.items():  # cast data dtype for paligemma
+            if torch.is_tensor(value) and torch.is_floating_point(value):
+                features[key] = value.to(self.compute_dtype)
+
         return features
 
 
diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index ae0da5ee..735c5d63 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -79,12 +79,13 @@ def check_dependencies() -> None:
     """
     if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
         logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
-    else:
-        require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
-        require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
-        require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
-        require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
-        require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
+        return
+
+    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
+    require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
+    require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
+    require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
+    require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
 
 
 def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 5b542c2e..4a254367 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -112,6 +112,10 @@ def _check_extra_dependencies(
     finetuning_args: "FinetuningArguments",
     training_args: Optional["Seq2SeqTrainingArguments"] = None,
 ) -> None:
+    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
+        logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
+        return
+
     if model_args.use_unsloth:
         require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
 
@@ -122,7 +126,7 @@ def _check_extra_dependencies(
         require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
 
     if model_args.infer_backend == "vllm":
-        require_version("vllm>=0.4.3,<0.6.5", "To fix: pip install vllm>=0.4.3,<0.6.5")
+        require_version("vllm>=0.4.3,<0.6.7", "To fix: pip install vllm>=0.4.3,<0.6.7")
 
     if finetuning_args.use_galore:
         require_version("galore_torch", "To fix: pip install galore_torch")
diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index 59862443..ad670385 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -29,7 +29,7 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override
 
 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46
+from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
 
@@ -50,6 +50,9 @@ class CustomDPOTrainer(DPOTrainer):
         disable_dropout: bool = True,
         **kwargs,
     ):
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+
         if disable_dropout:
             disable_dropout_in_model(model)
         if ref_model is not None:
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index f60c2366..3c6d1089 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -28,7 +28,7 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override
 
 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46
+from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
 
@@ -50,6 +50,9 @@ class CustomKTOTrainer(KTOTrainer):
         disable_dropout: bool = True,
         **kwargs,
     ):
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+
         if disable_dropout:
             disable_dropout_in_model(model)
         if ref_model is not None: