Former-commit-id: eeab92323899694010469451b8dfb1f00d685bff
This commit is contained in:
hiyouga 2023-09-23 00:42:23 +08:00
parent 35d1921081
commit 28062c71b5

View File

@@ -76,16 +76,18 @@ eval_templates = {
@torch.inference_mode()
def batch_inference(
chat_model: ChatModel,
batch_input: Dict[str, torch.Tensor]
batch_input: Dict[str, torch.Tensor],
lang: Literal["zh", "en"]
) -> List[str]:
prefix_char = "\n" if lang == "zh" else " "
logits = chat_model.model(**batch_input).logits
probs = torch.nn.functional.softmax(
torch.stack(
[
logits[:, -1, chat_model.tokenizer.encode("\nA")[-1]],
logits[:, -1, chat_model.tokenizer.encode("\nB")[-1]],
logits[:, -1, chat_model.tokenizer.encode("\nC")[-1]],
logits[:, -1, chat_model.tokenizer.encode("\nD")[-1]]
logits[:, -1, chat_model.tokenizer.encode(prefix_char + "A")[-1]],
logits[:, -1, chat_model.tokenizer.encode(prefix_char + "B")[-1]],
logits[:, -1, chat_model.tokenizer.encode(prefix_char + "C")[-1]],
logits[:, -1, chat_model.tokenizer.encode(prefix_char + "D")[-1]]
],
dim=-1
),
@@ -156,7 +158,7 @@ def evaluate(
return_attention_mask=True,
return_tensors="pt"
).to(chat_model.model.device)
preds = batch_inference(chat_model, batch_input)
preds = batch_inference(chat_model, batch_input, lang)
outputs += preds
corrects = (np.array(outputs) == np.array(labels))