use seed in evaluate.py

Former-commit-id: ab5cac1dfa681933f3266827f80068ce798b4c56
This commit is contained in:
hiyouga 2023-11-06 18:17:51 +08:00
parent 8a79b9938c
commit 441d9ae0ef

View File

@ -9,6 +9,7 @@ import fire
import json import json
import torch import torch
import numpy as np import numpy as np
import transformers
from collections import Counter from collections import Counter
from datasets import load_dataset from datasets import load_dataset
from dataclasses import dataclass from dataclasses import dataclass
@ -111,11 +112,13 @@ def evaluate(
n_shot: Optional[int] = 5, n_shot: Optional[int] = 5,
n_avg: Optional[int] = 1, n_avg: Optional[int] = 1,
batch_size: Optional[int] = 4, batch_size: Optional[int] = 4,
save_name: Optional[str] = None save_name: Optional[str] = None,
seed: Optional[int] = 42
): ):
with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f: with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f) categorys: Dict[str, Dict[str, str]] = json.load(f)
transformers.set_seed(seed)
chat_model = ChatModel(dict( chat_model = ChatModel(dict(
model_name_or_path=model_name_or_path, model_name_or_path=model_name_or_path,
finetuning_type=finetuning_type, finetuning_type=finetuning_type,