use seed in evaluate.py

Former-commit-id: ab5cac1dfa681933f3266827f80068ce798b4c56
This commit is contained in:
hiyouga 2023-11-06 18:17:51 +08:00
parent ba3e8ba20c
commit 21ac46e439

View File

@ -9,6 +9,7 @@ import fire
import json import json
import torch import torch
import numpy as np import numpy as np
import transformers
from collections import Counter from collections import Counter
from datasets import load_dataset from datasets import load_dataset
from dataclasses import dataclass from dataclasses import dataclass
@ -111,11 +112,13 @@ def evaluate(
n_shot: Optional[int] = 5, n_shot: Optional[int] = 5,
n_avg: Optional[int] = 1, n_avg: Optional[int] = 1,
batch_size: Optional[int] = 4, batch_size: Optional[int] = 4,
save_name: Optional[str] = None save_name: Optional[str] = None,
seed: Optional[int] = 42
): ):
with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f: with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f) categorys: Dict[str, Dict[str, str]] = json.load(f)
transformers.set_seed(seed)
chat_model = ChatModel(dict( chat_model = ChatModel(dict(
model_name_or_path=model_name_or_path, model_name_or_path=model_name_or_path,
finetuning_type=finetuning_type, finetuning_type=finetuning_type,