use seed in evaluate.py

Former-commit-id: de95b6928293e79c7e204be307c1784ce146c1b1
This commit is contained in:
hiyouga 2023-11-06 18:17:51 +08:00
parent 2627f95ac3
commit 5c19786f7c

View File

@ -9,6 +9,7 @@ import fire
import json import json
import torch import torch
import numpy as np import numpy as np
import transformers
from collections import Counter from collections import Counter
from datasets import load_dataset from datasets import load_dataset
from dataclasses import dataclass from dataclasses import dataclass
@ -111,11 +112,13 @@ def evaluate(
n_shot: Optional[int] = 5, n_shot: Optional[int] = 5,
n_avg: Optional[int] = 1, n_avg: Optional[int] = 1,
batch_size: Optional[int] = 4, batch_size: Optional[int] = 4,
save_name: Optional[str] = None save_name: Optional[str] = None,
seed: Optional[int] = 42
): ):
with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f: with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f) categorys: Dict[str, Dict[str, str]] = json.load(f)
transformers.set_seed(seed)
chat_model = ChatModel(dict( chat_model = ChatModel(dict(
model_name_or_path=model_name_or_path, model_name_or_path=model_name_or_path,
finetuning_type=finetuning_type, finetuning_type=finetuning_type,