mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 19:52:50 +08:00
fix MMLU
Former-commit-id: 2340b0d7df6b8b0f0e528bd0711a64ac42490566
This commit is contained in:
parent
5ee1bdecdc
commit
c714e445d8
@ -76,16 +76,18 @@ eval_templates = {
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def batch_inference(
|
def batch_inference(
|
||||||
chat_model: ChatModel,
|
chat_model: ChatModel,
|
||||||
batch_input: Dict[str, torch.Tensor]
|
batch_input: Dict[str, torch.Tensor],
|
||||||
|
lang: Literal["zh", "en"]
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
|
prefix_char = "\n" if lang == "zh" else " "
|
||||||
logits = chat_model.model(**batch_input).logits
|
logits = chat_model.model(**batch_input).logits
|
||||||
probs = torch.nn.functional.softmax(
|
probs = torch.nn.functional.softmax(
|
||||||
torch.stack(
|
torch.stack(
|
||||||
[
|
[
|
||||||
logits[:, -1, chat_model.tokenizer.encode("\nA")[-1]],
|
logits[:, -1, chat_model.tokenizer.encode(prefix_char + "A")[-1]],
|
||||||
logits[:, -1, chat_model.tokenizer.encode("\nB")[-1]],
|
logits[:, -1, chat_model.tokenizer.encode(prefix_char + "B")[-1]],
|
||||||
logits[:, -1, chat_model.tokenizer.encode("\nC")[-1]],
|
logits[:, -1, chat_model.tokenizer.encode(prefix_char + "C")[-1]],
|
||||||
logits[:, -1, chat_model.tokenizer.encode("\nD")[-1]]
|
logits[:, -1, chat_model.tokenizer.encode(prefix_char + "D")[-1]]
|
||||||
],
|
],
|
||||||
dim=-1
|
dim=-1
|
||||||
),
|
),
|
||||||
@ -156,7 +158,7 @@ def evaluate(
|
|||||||
return_attention_mask=True,
|
return_attention_mask=True,
|
||||||
return_tensors="pt"
|
return_tensors="pt"
|
||||||
).to(chat_model.model.device)
|
).to(chat_model.model.device)
|
||||||
preds = batch_inference(chat_model, batch_input)
|
preds = batch_inference(chat_model, batch_input, lang)
|
||||||
outputs += preds
|
outputs += preds
|
||||||
|
|
||||||
corrects = (np.array(outputs) == np.array(labels))
|
corrects = (np.array(outputs) == np.array(labels))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user