use seed in evaluate.py

Former-commit-id: ab5cac1dfa681933f3266827f80068ce798b4c56
This commit is contained in:
hiyouga 2023-11-06 18:17:51 +08:00
parent ba3e8ba20c
commit 21ac46e439

View File

@ -9,6 +9,7 @@ import fire
import json
import torch
import numpy as np
import transformers
from collections import Counter
from datasets import load_dataset
from dataclasses import dataclass
@ -111,11 +112,13 @@ def evaluate(
n_shot: Optional[int] = 5,
n_avg: Optional[int] = 1,
batch_size: Optional[int] = 4,
save_name: Optional[str] = None
save_name: Optional[str] = None,
seed: Optional[int] = 42
):
with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f)
transformers.set_seed(seed)
chat_model = ChatModel(dict(
model_name_or_path=model_name_or_path,
finetuning_type=finetuning_type,