fix sub-process error in thread

Former-commit-id: 9c10854b46fc67d45ff9fc2f62af4d6af826f9c1
This commit is contained in:
hiyouga 2024-03-03 15:04:35 +08:00
parent d966aee105
commit 0e58cd6422
2 changed files with 5 additions and 4 deletions

View File

@ -46,8 +46,8 @@ def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]
continue
plt.figure()
plt.plot(steps, metrics, alpha=0.4, label="original")
plt.plot(steps, smooth(metrics), label="smoothed")
plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
plt.title("training {} of {}".format(key, save_dictionary))
plt.xlabel("step")
plt.ylabel(key)

View File

@ -82,6 +82,7 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplot
return
plt.close("all")
plt.switch_backend("agg")
fig = plt.figure()
ax = fig.add_subplot(111)
steps, losses = [], []
@ -95,8 +96,8 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplot
if len(losses) == 0:
return None
ax.plot(steps, losses, alpha=0.4, label="original")
ax.plot(steps, smooth(losses), label="smoothed")
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
ax.legend()
ax.set_xlabel("step")
ax.set_ylabel("loss")