This commit is contained in:
hiyouga
2023-10-20 23:28:52 +08:00
parent 0fcf66049d
commit b665e9e133
5 changed files with 44 additions and 48 deletions

View File

@@ -84,10 +84,12 @@ def batch_inference(
prefix_char: str
) -> List[str]:
logits = chat_model.model(**batch_input).logits
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
nextword_logits = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
probs = torch.nn.functional.softmax(
torch.stack(
[
logits[:, -1, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]]
nextword_logits[:, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]]
for choice in choices
],
dim=-1
@@ -120,8 +122,8 @@ def evaluate(
checkpoint_dir=checkpoint_dir,
template=template
))
chat_model.tokenizer.padding_side = "left" # avoid overflow issue in batched inference for llama2
eval_template = eval_templates[lang]
assert chat_model.tokenizer.padding_side == "left", "only left-padded tensor can be accepted."
category_corrects: Dict[str, np.ndarray] = {
subj: np.array([], dtype="bool") for subj in ["Average", "STEM", "Social Sciences", "Humanities", "Other"]