Merge branch 'hiyouga:main' into pixtral-patch

Former-commit-id: 67f59579d7
This commit is contained in:
Kingsley
2024-10-29 21:01:25 +08:00
committed by GitHub
103 changed files with 1165 additions and 1092 deletions

View File

@@ -72,4 +72,4 @@ def print_env() -> None:
except Exception:
pass
print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")

View File

@@ -75,7 +75,7 @@ def _get_default_logging_level() -> "logging._Level":
if env_level_str.upper() in logging._nameToLevel:
return logging._nameToLevel[env_level_str.upper()]
else:
raise ValueError("Unknown logging level: {}.".format(env_level_str))
raise ValueError(f"Unknown logging level: {env_level_str}.")
return _default_log_level

View File

@@ -81,7 +81,7 @@ def check_dependencies() -> None:
else:
require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
require_version("accelerate>=0.30.1,<=0.34.2", "To fix: pip install accelerate>=0.30.1,<=0.34.2")
require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")

View File

@@ -79,6 +79,11 @@ def is_transformers_version_greater_than_4_43():
return _get_package_version("transformers") >= version.parse("4.43.0")
@lru_cache
def is_transformers_version_equal_to_4_46():
return _get_package_version("transformers") == version.parse("4.46.0")
def is_uvicorn_available():
return _is_package_available("uvicorn")

View File

@@ -75,7 +75,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
Plots loss curves and saves the image.
"""
plt.switch_backend("agg")
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
data = json.load(f)
for key in keys:
@@ -92,7 +92,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
plt.figure()
plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
plt.title("training {} of {}".format(key, save_dictionary))
plt.title(f"training {key} of {save_dictionary}")
plt.xlabel("step")
plt.ylabel(key)
plt.legend()