From df722bf18e8ce0484054edb1cbbdee2eb33195f3 Mon Sep 17 00:00:00 2001 From: KUANGDD Date: Mon, 14 Oct 2024 21:11:09 +0800 Subject: [PATCH] required transformers version Former-commit-id: 9f44598b92e72cf8dd923eb229f4637ab9287948 --- src/llamafactory/data/mm_plugin.py | 6 ++++++ src/llamafactory/extras/misc.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index b9e7bc3b..9d81848b 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -692,6 +692,12 @@ def get_mm_plugin( video_token: Optional[str] = None, ) -> "BasePlugin": plugin_class = PLUGINS.get(name, None) + if plugin_class == "PixtralPlugin": + from transformers.utils.versions import require_version + try: + require_version("transformers==4.46.0.dev0") + except Exception as e: + raise ImportError("PixtralPlugin requires transformers==4.46.0.dev0. Please install it first.") if plugin_class is None: raise ValueError("Multimodal plugin `{}` not found.".format(name)) diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index fd78530a..47f2ebbe 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -79,7 +79,7 @@ def check_dependencies() -> None: if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: logger.warning("Version checking has been disabled, may lead to unexpected behaviors.") else: - require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2") + require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2,<=4.45.2") require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0") require_version("accelerate>=0.30.1,<=0.34.2", "To fix: pip install accelerate>=0.30.1,<=0.34.2") require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")