mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 03:10:35 +08:00
Merge branch 'hiyouga:main' into pixtral-patch
Former-commit-id: 67f59579d7
This commit is contained in:
@@ -72,4 +72,4 @@ def print_env() -> None:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
|
||||
print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
|
||||
|
||||
@@ -75,7 +75,7 @@ def _get_default_logging_level() -> "logging._Level":
|
||||
if env_level_str.upper() in logging._nameToLevel:
|
||||
return logging._nameToLevel[env_level_str.upper()]
|
||||
else:
|
||||
raise ValueError("Unknown logging level: {}.".format(env_level_str))
|
||||
raise ValueError(f"Unknown logging level: {env_level_str}.")
|
||||
|
||||
return _default_log_level
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ def check_dependencies() -> None:
|
||||
else:
|
||||
require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
|
||||
require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
|
||||
require_version("accelerate>=0.30.1,<=0.34.2", "To fix: pip install accelerate>=0.30.1,<=0.34.2")
|
||||
require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
|
||||
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
|
||||
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
|
||||
|
||||
|
||||
@@ -79,6 +79,11 @@ def is_transformers_version_greater_than_4_43():
|
||||
return _get_package_version("transformers") >= version.parse("4.43.0")
|
||||
|
||||
|
||||
@lru_cache
|
||||
def is_transformers_version_equal_to_4_46():
|
||||
return _get_package_version("transformers") == version.parse("4.46.0")
|
||||
|
||||
|
||||
def is_uvicorn_available():
|
||||
return _is_package_available("uvicorn")
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
|
||||
Plots loss curves and saves the image.
|
||||
"""
|
||||
plt.switch_backend("agg")
|
||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
for key in keys:
|
||||
@@ -92,7 +92,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
|
||||
plt.figure()
|
||||
plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
|
||||
plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
|
||||
plt.title("training {} of {}".format(key, save_dictionary))
|
||||
plt.title(f"training {key} of {save_dictionary}")
|
||||
plt.xlabel("step")
|
||||
plt.ylabel(key)
|
||||
plt.legend()
|
||||
|
||||
Reference in New Issue
Block a user