fix incorrect loss value for vlms

Former-commit-id: 30567a1487
This commit is contained in:
hiyouga
2024-10-30 08:56:46 +00:00
parent 1b02915d19
commit 584ce3a105
12 changed files with 48 additions and 22 deletions

View File

@@ -79,8 +79,8 @@ def check_dependencies() -> None:
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
else:
require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
require_version("datasets>=2.16.0,<=3.0.2", "To fix: pip install datasets>=2.16.0,<=3.0.2")
require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
@@ -237,7 +237,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
if use_modelscope():
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
from modelscope import snapshot_download
from modelscope import snapshot_download # type: ignore
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
return snapshot_download(
@@ -248,7 +248,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
if use_openmind():
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
from openmind.utils.hub import snapshot_download
from openmind.utils.hub import snapshot_download # type: ignore
return snapshot_download(
model_args.model_name_or_path,

View File

@@ -81,7 +81,7 @@ def is_transformers_version_greater_than_4_43():
@lru_cache
def is_transformers_version_equal_to_4_46():
return _get_package_version("transformers") == version.parse("4.46.0")
return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")
def is_uvicorn_available():