mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
use seed in evaluate.py
Former-commit-id: de95b6928293e79c7e204be307c1784ce146c1b1
This commit is contained in:
parent
2627f95ac3
commit
5c19786f7c
@ -9,6 +9,7 @@ import fire
|
|||||||
import json
|
import json
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import transformers
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -111,11 +112,13 @@ def evaluate(
|
|||||||
n_shot: Optional[int] = 5,
|
n_shot: Optional[int] = 5,
|
||||||
n_avg: Optional[int] = 1,
|
n_avg: Optional[int] = 1,
|
||||||
batch_size: Optional[int] = 4,
|
batch_size: Optional[int] = 4,
|
||||||
save_name: Optional[str] = None
|
save_name: Optional[str] = None,
|
||||||
|
seed: Optional[int] = 42
|
||||||
):
|
):
|
||||||
with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
|
with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
|
||||||
categorys: Dict[str, Dict[str, str]] = json.load(f)
|
categorys: Dict[str, Dict[str, str]] = json.load(f)
|
||||||
|
|
||||||
|
transformers.set_seed(seed)
|
||||||
chat_model = ChatModel(dict(
|
chat_model = ChatModel(dict(
|
||||||
model_name_or_path=model_name_or_path,
|
model_name_or_path=model_name_or_path,
|
||||||
finetuning_type=finetuning_type,
|
finetuning_type=finetuning_type,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user