Fix evaluator and cached_file for compatibility with transformers 4.31.0

Former-commit-id: ff6056405dea8e89a95fd3741fd309d3c7679896
This commit is contained in:
hiyouga 2023-11-18 19:39:23 +08:00
parent 6d8d8509da
commit 112108d564
2 changed files with 37 additions and 12 deletions

View File

@ -3,6 +3,7 @@
import os import os
import json import json
import torch import torch
import inspect
import tiktoken import tiktoken
import numpy as np import numpy as np
from tqdm import tqdm, trange from tqdm import tqdm, trange
@ -45,13 +46,18 @@ class Evaluator:
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)] return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
def eval(self) -> None: def eval(self) -> None:
if "token" in inspect.signature(cached_file).parameters:
kwargs = {"token": self.model_args.hf_hub_token}
elif "use_auth_token" in inspect.signature(cached_file).parameters: # for transformers==4.31.0
kwargs = {"use_auth_token": self.model_args.hf_hub_token}
mapping = cached_file( mapping = cached_file(
path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task), path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task),
filename="mapping.json", filename="mapping.json",
cache_dir=self.model_args.cache_dir, cache_dir=self.model_args.cache_dir,
token=self.model_args.hf_hub_token, **kwargs
revision=self.model_args.model_revision
) )
with open(mapping, "r", encoding="utf-8") as f: with open(mapping, "r", encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f) categorys: Dict[str, Dict[str, str]] = json.load(f)
@ -62,7 +68,9 @@ class Evaluator:
dataset = load_dataset( dataset = load_dataset(
path=os.path.join(self.eval_args.task_dir, self.eval_args.task), path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
name=subject, name=subject,
download_mode="force_redownload" cache_dir=self.model_args.cache_dir,
download_mode=self.eval_args.download_mode,
token=self.model_args.hf_hub_token
) )
pbar.set_postfix_str(categorys[subject]["name"]) pbar.set_postfix_str(categorys[subject]["name"])
inputs, outputs, labels = [], [], [] inputs, outputs, labels = [], [], []

View File

@ -1,4 +1,5 @@
import torch import torch
import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from transformers.utils import cached_file from transformers.utils import cached_file
@ -94,19 +95,35 @@ def load_valuehead_params(
""" """
kwargs = { kwargs = {
"path_or_repo_id": path_or_repo_id, "path_or_repo_id": path_or_repo_id,
"cache_dir": model_args.cache_dir, "cache_dir": model_args.cache_dir
"token": model_args.hf_hub_token
} }
if "token" in inspect.signature(cached_file).parameters:
kwargs["token"] = model_args.hf_hub_token
elif "use_auth_token" in inspect.signature(cached_file).parameters: # for transformers==4.31.0
kwargs["use_auth_token"] = model_args.hf_hub_token
else:
logger.warning("Ignore `hf_hub_token` since matched parameter is not found.")
try: try:
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs) vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
except: return torch.load(vhead_file, map_location="cpu")
try: except Exception as err:
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs) logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
except:
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
return None
return torch.load(vhead_file, map_location="cpu") try:
from safetensors import safe_open
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {
"v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
"v_head.summary.bias": f.get_tensor("v_head.summary.bias")
}
except Exception as err:
logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
return None
def prepare_model_for_training( def prepare_model_for_training(