diff --git a/src/evaluate.py b/src/evaluate.py index d7511f6b..8af8c12c 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -9,6 +9,7 @@ import fire import json import torch import numpy as np +import transformers from collections import Counter from datasets import load_dataset from dataclasses import dataclass @@ -111,11 +112,13 @@ def evaluate( n_shot: Optional[int] = 5, n_avg: Optional[int] = 1, batch_size: Optional[int] = 4, - save_name: Optional[str] = None + save_name: Optional[str] = None, + seed: Optional[int] = 42 ): with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f: categorys: Dict[str, Dict[str, str]] = json.load(f) + transformers.set_seed(seed) chat_model = ChatModel(dict( model_name_or_path=model_name_or_path, finetuning_type=finetuning_type,