use seed in evaluate.py

Former-commit-id: de95b69282
This commit is contained in:
hiyouga
2023-11-06 18:17:51 +08:00
parent 2627f95ac3
commit 5c19786f7c

View File

@@ -9,6 +9,7 @@ import fire
import json import json
import torch import torch
import numpy as np import numpy as np
import transformers
from collections import Counter from collections import Counter
from datasets import load_dataset from datasets import load_dataset
from dataclasses import dataclass from dataclasses import dataclass
@@ -111,11 +112,13 @@ def evaluate(
n_shot: Optional[int] = 5, n_shot: Optional[int] = 5,
n_avg: Optional[int] = 1, n_avg: Optional[int] = 1,
batch_size: Optional[int] = 4, batch_size: Optional[int] = 4,
save_name: Optional[str] = None save_name: Optional[str] = None,
seed: Optional[int] = 42
): ):
with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f: with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f) categorys: Dict[str, Dict[str, str]] = json.load(f)
transformers.set_seed(seed)
chat_model = ChatModel(dict( chat_model = ChatModel(dict(
model_name_or_path=model_name_or_path, model_name_or_path=model_name_or_path,
finetuning_type=finetuning_type, finetuning_type=finetuning_type,