use seed in evaluate.py

This commit is contained in:
hiyouga
2023-11-06 18:17:51 +08:00
parent e1e04cb1f1
commit de95b69282

View File

@@ -9,6 +9,7 @@ import fire
import json import json
import torch import torch
import numpy as np import numpy as np
import transformers
from collections import Counter from collections import Counter
from datasets import load_dataset from datasets import load_dataset
from dataclasses import dataclass from dataclasses import dataclass
@@ -111,11 +112,13 @@ def evaluate(
n_shot: Optional[int] = 5, n_shot: Optional[int] = 5,
n_avg: Optional[int] = 1, n_avg: Optional[int] = 1,
batch_size: Optional[int] = 4, batch_size: Optional[int] = 4,
save_name: Optional[str] = None save_name: Optional[str] = None,
seed: Optional[int] = 42
): ):
with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f: with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f) categorys: Dict[str, Dict[str, str]] = json.load(f)
transformers.set_seed(seed)
chat_model = ChatModel(dict( chat_model = ChatModel(dict(
model_name_or_path=model_name_or_path, model_name_or_path=model_name_or_path,
finetuning_type=finetuning_type, finetuning_type=finetuning_type,