use seed in evaluate.py

Former-commit-id: de95b6928293e79c7e204be307c1784ce146c1b1
This commit is contained in:
hiyouga 2023-11-06 18:17:51 +08:00
parent 2627f95ac3
commit 5c19786f7c

View File

@ -9,6 +9,7 @@ import fire
import json
import torch
import numpy as np
import transformers
from collections import Counter
from datasets import load_dataset
from dataclasses import dataclass
@ -111,11 +112,13 @@ def evaluate(
n_shot: Optional[int] = 5,
n_avg: Optional[int] = 1,
batch_size: Optional[int] = 4,
save_name: Optional[str] = None
save_name: Optional[str] = None,
seed: Optional[int] = 42
):
with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f)
transformers.set_seed(seed)
chat_model = ChatModel(dict(
model_name_or_path=model_name_or_path,
finetuning_type=finetuning_type,