Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-14 10:56:56 +08:00
@@ -84,10 +84,12 @@ def batch_inference(
     prefix_char: str
 ) -> List[str]:
     logits = chat_model.model(**batch_input).logits
+    lengths = torch.sum(batch_input["attention_mask"], dim=-1)
+    nextword_logits = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
     probs = torch.nn.functional.softmax(
         torch.stack(
             [
-                logits[:, -1, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]]
+                nextword_logits[:, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]]
                 for choice in choices
             ],
             dim=-1
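
Why the gather changed: in a right-padded batch, logits[:, -1] reads the logit at a padding position for every sequence shorter than the batch maximum, while lengths[i] - 1, computed from the attention mask, lands on each sequence's last real token. A minimal sketch of the difference, assuming right padding and using toy shapes and stand-in choice ids rather than the project's actual API:

import torch

# Toy batch: 2 sequences, 5 positions, vocabulary of 10.
batch, seq_len, vocab = 2, 5, 10
logits = torch.randn(batch, seq_len, vocab)

# Right padding: the second sequence has only 3 real tokens.
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])
lengths = torch.sum(attention_mask, dim=-1)  # tensor([5, 3])

# New behavior: take the logit at each sequence's last real token.
nextword_logits = torch.stack([logits[i, lengths[i] - 1] for i in range(batch)], dim=0)

# Old behavior: position -1 is a pad slot for the shorter sequence.
assert torch.equal(nextword_logits[0], logits[0, -1])      # full-length row agrees
assert not torch.equal(nextword_logits[1], logits[1, -1])  # padded row differs

# Downstream, the evaluator softmaxes the gathered logits over the answer
# letters; sketched here with stand-in token ids for "A".."D".
choice_ids = [0, 1, 2, 3]
probs = torch.nn.functional.softmax(nextword_logits[:, choice_ids], dim=-1)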
@@ -120,8 +122,8 @@ def evaluate(
         checkpoint_dir=checkpoint_dir,
         template=template
     ))
-    chat_model.tokenizer.padding_side = "left" # avoid overflow issue in batched inference for llama2
     eval_template = eval_templates[lang]
+    assert chat_model.tokenizer.padding_side == "left", "only left-padded tensor can be accepted."

     category_corrects: Dict[str, np.ndarray] = {
         subj: np.array([], dtype="bool") for subj in ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
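
The removed assignment configured padding inside evaluate(); the assert turns the padding side into a precondition the caller must satisfy instead. For reference, a small sketch of what padding_side controls on a Hugging Face tokenizer (the model name is illustrative; any causal-LM tokenizer behaves the same way):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token  # gpt2 ships without a pad token
tok.padding_side = "left"

enc = tok(["one two three", "one"], padding=True, return_tensors="pt")
print(enc["attention_mask"])  # rows: [1, 1, 1] and [0, 0, 1]; pads sit on the left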
@@ -289,8 +289,8 @@ register_template(


 r"""
-Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
-          https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
+Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
+          https://huggingface.co/ziqingyang/chinese-alpaca-2-13b
 """
 register_template(
     name="llama2_zh",
@@ -307,7 +307,6 @@ register_template(

 r"""
 Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
-          https://github.com/ymcui/Chinese-LLaMA-Alpaca
 """
 register_template(
     name="alpaca",
@@ -328,8 +327,8 @@ register_template(


 r"""
-Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
-          https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
+Supports: https://huggingface.co/lmsys/vicuna-7b-v1.5
+          https://huggingface.co/lmsys/vicuna-13b-v1.5
 """
 register_template(
     name="vicuna",
@@ -365,44 +364,9 @@ register_template(
 )


-r"""
-Supports: https://github.com/CVI-SZU/Linly
-"""
-register_template(
-    name="linly",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        "User: {{query}}\nBot: "
-    ],
-    system="",
-    sep=[
-        "\n"
-    ]
-)
-
-
-r"""
-Supports: https://github.com/Neutralzz/BiLLa
-"""
-register_template(
-    name="billa",
-    prefix=[
-        "{{system}}"
-    ],
-    prompt=[
-        "Human: {{query}}\nAssistant: "
-    ],
-    system="",
-    sep=[
-        "\n"
-    ]
-)
-
-
 r"""
 Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
+          https://huggingface.co/IDEA-CCNL/Ziya2-13B-Chat
 """
 register_template(
     name="ziya",
@@ -424,6 +388,8 @@ register_template(

 r"""
 Supports: https://huggingface.co/BAAI/AquilaChat-7B
+          https://huggingface.co/BAAI/AquilaChat2-7B
+          https://huggingface.co/BAAI/AquilaChat2-34B
 """
 register_template(
     name="aquila",
@@ -449,6 +415,7 @@ register_template(

 r"""
 Supports: https://huggingface.co/internlm/internlm-chat-7b
+          https://huggingface.co/internlm/internlm-chat-20b
 """
 register_template(
     name="intern",
@@ -542,6 +509,7 @@ register_template(

 r"""
 Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
+          https://huggingface.co/Qwen/Qwen-14B-Chat
 """
 register_template(
     name="chatml",
@@ -591,7 +559,29 @@ register_template(


 r"""
-Supports: https://huggingface.co/xverse/XVERSE-13B-Chat
+Supports: https://huggingface.co/openchat/openchat_v3.2_super
 """
 register_template(
-    name="xverse",
+    name="openchat",
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "GPT4 User: {{query}}",
+        {"token": "<|end_of_turn|>"},
+        "GPT4 Assistant: "
+    ],
+    system="",
+    sep=[
+        {"token": "<|end_of_turn|>"}
+    ],
+    efficient_eos=True
+)
+
+
+r"""
+Supports: https://huggingface.co/xverse/XVERSE-7B-Chat
+          https://huggingface.co/xverse/XVERSE-13B-Chat
+"""
+register_template(
+    name="xverse",
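
Roughly, the registry's prefix, prompt, and sep fields are concatenated around each turn, with {"token": ...} entries resolved to tokenizer special tokens. An illustrative rendering of the new openchat template for a single turn, writing <|end_of_turn|> out as literal text rather than resolving it through the tokenizer:

def render_openchat(query: str, system: str = "") -> str:
    # prefix "{{system}}", then the prompt pieces "GPT4 User: {{query}}",
    # <|end_of_turn|>, and "GPT4 Assistant: "
    return f"{system}GPT4 User: {query}<|end_of_turn|>GPT4 Assistant: "

print(render_openchat("How are you?"))
# GPT4 User: How are you?<|end_of_turn|>GPT4 Assistant: 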
@@ -113,6 +113,8 @@ class DataArguments:
             with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
                 dataset_info = json.load(f)
         except Exception:
+            if self.dataset is not None:
+                raise ValueError("Cannot find dataset_info.json in `dataset_dir`.")
             dataset_info = None

         prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
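
The new branch makes a missing or unreadable dataset_info.json fatal only when a dataset was actually requested; otherwise the silent None fallback remains. A self-contained sketch of the resulting control flow (the function name and signature are stand-ins, not the class's real API):

import json
import os
from typing import Optional

def load_dataset_info(dataset_dir: str, dataset: Optional[str]) -> Optional[dict]:
    try:
        with open(os.path.join(dataset_dir, "dataset_info.json"), "r") as f:
            return json.load(f)
    except Exception:
        if dataset is not None:
            # a dataset was requested but its registry cannot be read: fail fast
            raise ValueError("Cannot find dataset_info.json in `dataset_dir`.")
        return None  # nothing requested: keep the permissive fallback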